--- /dev/null
+import dns
+import os
+import socket
+import struct
+import sys
+import threading
+import time
+
+from recursortests import RecursorTest
+
+class RPZServer(object):
+
+ def __init__(self, port):
+ self._currentSerial = 0
+ self._targetSerial = 1
+ self._serverPort = port
+ listener = threading.Thread(name='RPZ Listener', target=self._listener, args=[])
+ listener.setDaemon(True)
+ listener.start()
+
+ def getCurrentSerial(self):
+ return self._currentSerial
+
+ def moveToSerial(self, newSerial):
+ if newSerial == self._currentSerial:
+ return False
+
+ if newSerial != self._currentSerial + 1:
+ raise AssertionError("Asking the RPZ server to server serial %d, already serving %d" % (newSerial, self._currentSerial))
+ self._targetSerial = newSerial
+ return True
+
+ def _getAnswer(self, message):
+
+ response = dns.message.make_response(message)
+ records = []
+
+ if message.question[0].rdtype == dns.rdatatype.AXFR:
+ if self._currentSerial != 0:
+ print('Received an AXFR query but IXFR expected because the current serial is %d' % (self._currentSerial))
+ return (None, self._currentSerial)
+
+ newSerial = self._targetSerial
+ records = [
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial)
+ ]
+
+ elif message.question[0].rdtype == dns.rdatatype.IXFR:
+ oldSerial = message.authority[0][0].serial
+
+ if oldSerial != self._currentSerial:
+ print('Received an IXFR query with an unexpected serial %d, expected %d' % (oldSerial, self._currentSerial))
+ return (None, self._currentSerial)
+
+ newSerial = self._targetSerial
+ if newSerial == 2:
+ records = [
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % oldSerial),
+ # no deletion
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('b.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial)
+ ]
+ elif newSerial == 3:
+ records = [
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % oldSerial),
+ dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ # no addition
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial)
+ ]
+ elif newSerial == 4:
+ records = [
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % oldSerial),
+ dns.rrset.from_text('b.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('c.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial)
+ ]
+
+ response.answer = records
+ return (newSerial, response)
+
+ def _connectionHandler(self, conn):
+ data = None
+ while True:
+ data = conn.recv(2)
+ if not data:
+ break
+ (datalen,) = struct.unpack("!H", data)
+ data = conn.recv(datalen)
+ if not data:
+ break
+
+ message = dns.message.from_wire(data)
+ if len(message.question) != 1:
+ print('Invalid RPZ query, qdcount is %d' % (len(message.question)))
+ break
+ if not message.question[0].rdtype in [dns.rdatatype.AXFR, dns.rdatatype.IXFR]:
+ print('Invalid RPZ query, qtype is %d' % (message.question.rdtype))
+ break
+ (serial, answer) = self._getAnswer(message)
+ if not answer:
+ print('Unable to get a response for %s %d' % (message.question[0].name, message.question[0].rdtype))
+ break
+
+ wire = answer.to_wire()
+ conn.send(struct.pack("!H", len(wire)))
+ conn.send(wire)
+ self._currentSerial = serial
+ break
+
+ conn.close()
+
+ def _listener(self):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+ try:
+ sock.bind(("127.0.0.1", self._serverPort))
+ except socket.error as e:
+ print("Error binding in the RPZ listener: %s" % str(e))
+ sys.exit(1)
+
+ sock.listen(100)
+ while True:
+ try:
+ (conn, _) = sock.accept()
+ thread = threading.Thread(name='RPZ Connection Handler',
+ target=self._connectionHandler,
+ args=[conn])
+ thread.setDaemon(True)
+ thread.start()
+
+ except socket.error as e:
+ print('Error in RPZ socket: %s' % str(e))
+ sock.close()
+
+rpzServerPort = 4244
+rpzServer = RPZServer(rpzServerPort)
+
+class RPZRecursorTest(RecursorTest):
+ """
+ This test makes sure that we correctly update RPZ zones via AXFR then IXFR
+ """
+
+ global rpzServerPort
+ _lua_config_file = """
+ rpzMaster('127.0.0.1:%d', 'zone.rpz.', { refresh=1 })
+ """ % (rpzServerPort)
+ _confdir = 'RPZ'
+ _config_template = """
+auth-zones=example=configs/%s/example.zone""" % _confdir
+
+
+ @classmethod
+ def generateRecursorConfig(cls, confdir):
+ authzonepath = os.path.join(confdir, 'example.zone')
+ with open(authzonepath, 'w') as authzone:
+ authzone.write("""$ORIGIN example.
+@ 3600 IN SOA {soa}
+a 3600 IN A 192.0.2.42
+b 3600 IN A 192.0.2.42
+c 3600 IN A 192.0.2.42
+""".format(soa=cls._SOA))
+ super(RPZRecursorTest, cls).generateRecursorConfig(confdir)
+
+ @classmethod
+ def setUpClass(cls):
+
+ cls.setUpSockets()
+ cls.startResponders()
+
+ confdir = os.path.join('configs', cls._confdir)
+ cls.createConfigDir(confdir)
+
+ cls.generateRecursorConfig(confdir)
+ cls.startRecursor(confdir, cls._recursorPort)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.tearDownRecursor()
+
+ def checkBlocked(self, name, shouldBeBlocked=True):
+ query = dns.message.make_query(name, 'A', want_dnssec=True)
+ query.flags |= dns.flags.CD
+ res = self.sendUDPQuery(query)
+ if shouldBeBlocked:
+ expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.1')
+ else:
+ expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.42')
+
+ self.assertRRsetInAnswer(res, expected)
+
+ def checkNotBlocked(self, name):
+ self.checkBlocked(name, False)
+
+ def waitUntilCorrectSerialIsLoaded(self, serial, timeout=5):
+ global rpzServer
+
+ rpzServer.moveToSerial(serial)
+
+ attempts = 0
+ while attempts < timeout:
+ currentSerial = rpzServer.getCurrentSerial()
+ if currentSerial > serial:
+ raise AssertionError("Expected serial %d, got %d" % (serial, currentSerial))
+ if currentSerial == serial:
+ return
+
+ attempts = attempts + 1
+ time.sleep(1)
+
+ raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial))
+
+ def testRPZ(self):
+ # first zone, only a should be blocked
+ self.waitUntilCorrectSerialIsLoaded(1)
+ self.checkBlocked('a.example.')
+ self.checkNotBlocked('b.example.')
+ self.checkNotBlocked('c.example.')
+
+ # second zone, a and b should be blocked
+ self.waitUntilCorrectSerialIsLoaded(2)
+ self.checkBlocked('a.example.')
+ self.checkBlocked('b.example.')
+ self.checkNotBlocked('c.example.')
+
+ # third zone, only b should be blocked
+ self.waitUntilCorrectSerialIsLoaded(3)
+ self.checkNotBlocked('a.example.')
+ self.checkBlocked('b.example.')
+ self.checkNotBlocked('c.example.')
+
+ # fourth zone, only c should be blocked
+ self.waitUntilCorrectSerialIsLoaded(4)
+ self.checkNotBlocked('a.example.')
+ self.checkNotBlocked('b.example.')
+ self.checkBlocked('c.example.')