Update the README
[public/dnssec-swede-utility.git] / swede
diff --git a/swede b/swede
index d2f5d0e2744fd11586276c60c98af4c2562635d5..260952819c2ffa6fd9b13d565d0ba76fff2ee3e9 100755 (executable)
--- a/swede
+++ b/swede
@@ -17,6 +17,7 @@
 
 import sys
 import os
+import socket
 import unbound
 import re
 from M2Crypto import X509, SSL
@@ -25,7 +26,7 @@ from hashlib import sha256, sha512
 from ipaddr import IPv4Address, IPv6Address
 
 
-def genTLSA(hostname, protocol, port, certificate, output='draft', usage=1, selector=0, mtype=1):
+def genTLSA(hostname, protocol, port, certificate, output='generic', usage=1, selector=0, mtype=1):
        """This function generates a TLSARecord object using the data passed in the parameters,
        it then validates the record and returns the RR as a string.
        """
@@ -53,8 +54,8 @@ def genTLSA(hostname, protocol, port, certificate, output='draft', usage=1, sele
 
        record.isValid(raiseException=True)
 
-       if output == 'draft':
-               return record.getRecord(draft=True)
+       if output == 'generic':
+               return record.getRecord(generic=True)
        return record.getRecord()
 
 def getA(hostname, secure=True):
@@ -170,7 +171,7 @@ def getHash(certificate, mtype):
 def getTLSA(hostname, port=443, protocol='tcp', secure=True):
        """
        This function tries to do a secure lookup of the TLSA record.
-       At the moment it requests the TYPE65468 record and parses it into a 'valid' TLSA record
+       At the moment it requests the TYPE52 record and parses it into a 'valid' TLSA record
        It returns a list of TLSARecord objects
        """
        if hostname[-1] != '.':
@@ -180,9 +181,9 @@ def getTLSA(hostname, port=443, protocol='tcp', secure=True):
                raise Exception('Error: unknown protocol: %s. Should be one of tcp, udp or sctp' % protocol)
        try:
                if port == '*':
-                       records = getRecords('*._%s.%s' % (protocol.lower(), hostname), rrtype=65468, secure=secure)
+                       records = getRecords('*._%s.%s' % (protocol.lower(), hostname), rrtype=52, secure=secure)
                else:
-                       records = getRecords('_%s._%s.%s' % (port, protocol.lower(), hostname), rrtype=65468, secure=secure)
+                       records = getRecords('_%s._%s.%s' % (port, protocol.lower(), hostname), rrtype=52, secure=secure)
        except InsecureLookupException, e:
                print str(e)
                sys.exit(1)
@@ -242,7 +243,7 @@ class TLSARecord:
                cert should be a hexidecimal string representing the certificate to be matched field
                """
                try:
-                       self.rrtype = 65468 # TLSA provisional
+                       self.rrtype = 52    # TLSA per https://www.iana.org/assignments/dns-parameters
                        self.rrclass = 1    # IN
                        self.name = str(name)
                        self.usage = int(usage)
@@ -252,10 +253,10 @@ class TLSARecord:
                except:
                        raise Exception('Invalid value passed, unable to create a TLSARecord')
 
-       def getRecord(self, draft=False):
-               """Returns the RR string of this TLSARecord, either in rfc (default) or draft format"""
-               if draft:
-                       return '%s IN TYPE65468 \# %s %s%s%s%s' % (self.name, (len(self.cert)/2)+3 , self._toHex(self.usage), self._toHex(self.selector), self._toHex(self.mtype), self.cert)
+       def getRecord(self, generic=False):
+               """Returns the RR string of this TLSARecord, either in rfc (default) or generic format"""
+               if generic:
+                       return '%s IN TYPE52 \# %s %s%s%s%s' % (self.name, (len(self.cert)/2)+3 , self._toHex(self.usage), self._toHex(self.selector), self._toHex(self.mtype), self.cert)
                return '%s IN TLSA %s %s %s %s' % (self.name, self.usage, self.selector, self.mtype, self.cert)
 
        def _toHex(self, val):
@@ -381,7 +382,7 @@ if __name__ == '__main__':
        parser_create.add_argument('--port', '-p', action='store', type=int, default=443, help='The port where running TLS is located (default: %(default)s).')
        parser_create.add_argument('--protocol', action='store', choices=['tcp','udp','sctp'], default='tcp', help='The protocol the TLS service is using (default: %(default)s).')
        parser_create.add_argument('--certificate', '-c', help='The certificate used for the host. If certificate is empty, the certificate will be downloaded from the server')
-       parser_create.add_argument('--output', '-o', action='store', default='draft', choices=['draft','rfc','both'], help='The type of output. Draft (private RRtype, 65468), RFC (TLSA) or both (default: %(default)s).')
+       parser_create.add_argument('--output', '-o', action='store', default='generic', choices=['generic','rfc','both'], help='The type of output. Generic (RFC 3597, TYPE52), RFC (TLSA) or both (default: %(default)s).')
 
        # Usage of the certificate
        parser_create.add_argument('--usage', '-u', action='store', type=int, default=1, choices=[0,1,2,3], help='The Usage of the Certificate for Association. \'0\' for CA, \'1\' for End Entity, \'2\' for trust-anchor, \'3\' for ONLY End-Entity match (default: %(default)s).')
@@ -441,7 +442,7 @@ if __name__ == '__main__':
 
                        if not args.quiet:
                                print 'Attempting to verify the record with the TLS service...'
-                       addresses = getA(args.host, secure=secure)
+                       addresses = getA(args.host, secure=secure) + getAAAA(args.host, secure=secure)
                        for address in addresses:
                                if not args.quiet:
                                        print 'Got the following IP: %s' % str(address)
@@ -456,7 +457,12 @@ if __name__ == '__main__':
                                        sys.exit(1)
                                # Don't error when the verification fails in the SSL handshake
                                ctx.set_verify(SSL.verify_none, depth=9)
-                               connection = SSL.Connection(ctx)
+                               if isinstance(address, AAAARecord):
+                                       sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+                                       sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                               else:
+                                       sock = None
+                               connection = SSL.Connection(ctx, sock=sock)
                                try:
                                        connection.connect((str(address), int(args.port)))
                                except SSL.Checker.WrongHost, e:
@@ -565,7 +571,7 @@ if __name__ == '__main__':
                                                sys.stdout.write('Port %s not numerical or within correct range (1 <= port <= 65535), try again (hit enter for default 443): ' % user_input)
                        # Get the A records for the host
                        try:
-                               addresses = getA(args.host, secure=secure)
+                               addresses = getA(args.host, secure=secure) + getAAAA(args.host, secure=secure)
                        except InsecureLookupException, e:
                                print >> sys.stderr, str(e)
                                sys.exit(1)
@@ -575,7 +581,12 @@ if __name__ == '__main__':
                                # We do the certificate handling here, as M2Crypto keeps segfaulting when try to do stuff with the cert if we don't
                                ctx = SSL.Context()
                                ctx.set_verify(SSL.verify_none, depth=9)
-                               connection = SSL.Connection(ctx)
+                               if isinstance(address, AAAARecord):
+                                       sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+                                       sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                               else:
+                                       sock = None
+                               connection = SSL.Connection(ctx, sock=sock)
                                try:
                                        connection.connect((str(address), int(connection_port)))
                                except SSL.Checker.WrongHost: