added gentoo-ebuild
[public/dnssec-swede-utility.git] / swede
diff --git a/swede b/swede
index 4be39b51e1c7d060ffe765f8e3def42d785ece0e..cd212a4c38b5595c387f9b4e7f9d89e25898f34c 100755 (executable)
--- a/swede
+++ b/swede
@@ -1,6 +1,6 @@
-#!/usr/bin/python
+#!/usr/bin/python2
 
-# swede - A tool to create DANE/TLSA (draft 14) records.
+# swede - A tool to create DANE/TLSA records.
 # This tool is really simple and not foolproof, it doesn't check the CN in the
 # Subject field of the certificate. It also doesn't check if the supplied
 # certificate is a CA certificate if usage 1 is specified (or any other
@@ -17,6 +17,8 @@
 
 import sys
 import os
+import os.path
+import socket
 import unbound
 import re
 from M2Crypto import X509, SSL
@@ -24,7 +26,11 @@ from binascii import a2b_hex, b2a_hex
 from hashlib import sha256, sha512
 from ipaddr import IPv4Address, IPv6Address
 
-def genTLSA(hostname, protocol, port, certificate, output='draft', usage=1, selector=0, mtype=1):
+check_ipv4=True
+check_ipv6=True
+
+
+def genTLSA(hostname, protocol, port, certificate, output='generic', usage=1, selector=0, mtype=1):
        """This function generates a TLSARecord object using the data passed in the parameters,
        it then validates the record and returns the RR as a string.
        """
@@ -52,21 +58,37 @@ def genTLSA(hostname, protocol, port, certificate, output='draft', usage=1, sele
 
        record.isValid(raiseException=True)
 
-       if output == 'draft':
-               return record.getRecord(draft=True)
+       if output == 'generic':
+               return record.getRecord(generic=True)
        return record.getRecord()
 
 def getA(hostname, secure=True):
+       if not check_ipv4: return []
        """Gets a list of A records for hostname, returns a list of ARecords"""
-       records = getRecords(hostname, rrtype='A', secure=secure)
+       try:
+               records = getRecords(hostname, rrtype='A', secure=secure)
+       except InsecureLookupException, e:
+               print str(e)
+               sys.exit(1)
+       except DNSLookupError, e:
+               print 'Unable to resolve %s: %s' % (hostname, str(e))
+               sys.exit(1)
        ret = []
        for record in records:
                ret.append(ARecord(hostname, str(IPv4Address(int(b2a_hex(record),16)))))
        return ret
 
 def getAAAA(hostname, secure=True):
+       if not check_ipv6: return []
        """Gets a list of A records for hostname, returns a list of AAAARecords"""
-       records = getRecords(hostname, rrtype='AAAA', secure=secure)
+       try:
+               records = getRecords(hostname, rrtype='AAAA', secure=secure)
+       except InsecureLookupException, e:
+               print str(e)
+               sys.exit(1)
+       except DNSLookupError, e:
+               print 'Unable to resolve %s: %s' % (hostname, str(e))
+               sys.exit(1)
        ret = []
        for record in records:
                ret.append(AAAARecord(hostname, str(IPv6Address(int(b2a_hex(record),16)))))
@@ -111,7 +133,22 @@ def getRecords(hostname, rrtype='A', secure=True):
        """Do a lookup of a name and a rrtype, returns a list of binary coded strings. Only queries for rr_class IN."""
        global resolvconf
        ctx = unbound.ub_ctx()
-       ctx.add_ta_file('root.key')
+       if os.path.exists("root.key"):
+               ctx.add_ta_file('root.key')
+       elif os.path.exists("/etc/swede/root.key"):
+               ctx.add_ta_file('/etc/swede/root.key')
+       else:
+               print "Cannot find root.key, please move it to /etc/swede"
+               sys.exit()
+
+       if os.path.exists("dlv.isc.org.key"):
+               ctx.set_option("dlv-anchor-file:", "dlv.isc.org.key")
+       elif os.path.exists("/etc/swede/dlv.isc.org.key"):
+               ctx.set_option("dlv-anchor-file:", "/etc/swede/dlv.isc.org.key")
+       else:
+               print "Cannot find dlv.isc.org.key, please move it to /etc/swede"
+               sys.exit()
+
        # Use the local cache
        if resolvconf and os.path.isfile(resolvconf):
                ctx.resolvconf(resolvconf)
@@ -135,7 +172,7 @@ def getRecords(hostname, rrtype='A', secure=True):
                # If we are here the data was either secure or insecure data is accepted
                return result.data.raw
        else:
-               raise Exception('Error: Unsuccesful lookup or no data returned.')
+               raise DNSLookupError('Unsuccesful lookup or no data returned for rrtype %s.' % rrtype)
 
 def getHash(certificate, mtype):
        """Hashes the certificate based on the mtype.
@@ -154,7 +191,7 @@ def getHash(certificate, mtype):
 def getTLSA(hostname, port=443, protocol='tcp', secure=True):
        """
        This function tries to do a secure lookup of the TLSA record.
-       At the moment it requests the TYPE65468 record and parses it into a 'valid' TLSA record
+       At the moment it requests the TYPE52 record and parses it into a 'valid' TLSA record
        It returns a list of TLSARecord objects
        """
        if hostname[-1] != '.':
@@ -164,12 +201,15 @@ def getTLSA(hostname, port=443, protocol='tcp', secure=True):
                raise Exception('Error: unknown protocol: %s. Should be one of tcp, udp or sctp' % protocol)
        try:
                if port == '*':
-                       records = getRecords('*._%s.%s' % (protocol.lower(), hostname), rrtype=65468, secure=secure)
+                       records = getRecords('*._%s.%s' % (protocol.lower(), hostname), rrtype=52, secure=secure)
                else:
-                       records = getRecords('_%s._%s.%s' % (port, protocol.lower(), hostname), rrtype=65468, secure=secure)
+                       records = getRecords('_%s._%s.%s' % (port, protocol.lower(), hostname), rrtype=52, secure=secure)
        except InsecureLookupException, e:
                print str(e)
                sys.exit(1)
+       except DNSLookupError, e:
+               print 'Unable to resolve %s: %s' % (hostname, str(e))
+               sys.exit(1)
        ret = []
        for record in records:
                hexdata = b2a_hex(record)
@@ -214,6 +254,28 @@ def verifyCertMatch(record, cert):
        else:
                return False
 
+def verifyCertNameWithHostName(cert, hostname, with_msg=False):
+       """Verify the name on the certificate with a hostname, we need this because we get the cert based on IP address and thusly cannot rely on M2Crypto to verify this"""
+       if not isinstance(cert, X509.X509):
+               return
+       if not isinstance(hostname, str):
+               return
+
+       if hostname[-1] == '.':
+               hostname = hostname[0:-1]
+
+       # Ugly string comparison to see if the name on the ee-cert matches with the name provided on the commandline
+       try:
+               altnames_on_cert = cert.get_ext('subjectAltName').get_value()
+       except:
+               altnames_on_cert = ''
+       if hostname in (str(cert.get_subject()) + altnames_on_cert):
+               return True
+       else:
+               if with_msg:
+                       print 'WARNING: Name on the certificate (Subject: %s, SubjectAltName: %s) doesn\'t match requested hostname (%s).' % (str(cert.get_subject()), altnames_on_cert, hostname)
+               return False
+
 class TLSARecord:
        """When instanciated, this class contains all the fields of a TLSA record.
        """
@@ -223,7 +285,7 @@ class TLSARecord:
                cert should be a hexidecimal string representing the certificate to be matched field
                """
                try:
-                       self.rrtype = 65468 # TLSA provisional
+                       self.rrtype = 52    # TLSA per https://www.iana.org/assignments/dns-parameters
                        self.rrclass = 1    # IN
                        self.name = str(name)
                        self.usage = int(usage)
@@ -233,10 +295,10 @@ class TLSARecord:
                except:
                        raise Exception('Invalid value passed, unable to create a TLSARecord')
 
-       def getRecord(self, draft=False):
-               """Returns the RR string of this TLSARecord, either in rfc (default) or draft format"""
-               if draft:
-                       return '%s IN TYPE65468 \# %s %s%s%s%s' % (self.name, (len(self.cert)/2)+3 , self._toHex(self.usage), self._toHex(self.selector), self._toHex(self.mtype), self.cert)
+       def getRecord(self, generic=False):
+               """Returns the RR string of this TLSARecord, either in rfc (default) or generic format"""
+               if generic:
+                       return '%s IN TYPE52 \# %s %s%s%s%s' % (self.name, (len(self.cert)/2)+3 , self._toHex(self.usage), self._toHex(self.selector), self._toHex(self.mtype), self.cert)
                return '%s IN TLSA %s %s %s %s' % (self.name, self.usage, self.selector, self.mtype, self.cert)
 
        def _toHex(self, val):
@@ -252,8 +314,8 @@ class TLSARecord:
                except:
                        if self.getPort() != '*':
                                err.append('Port %s not a number' % self.getPort())
-               if not self.usage in [0,1,2]:
-                       err.append('Usage: invalid (%s is not one of 0, 1 or 2)' % self.usage)
+               if not self.usage in [0,1,2,3]:
+                       err.append('Usage: invalid (%s is not one of 0, 1, 2 or 3)' % self.usage)
                if not self.selector in [0,1]:
                        err.append('Selector: invalid (%s is not one of 0 or 1)' % self.selector)
                if not self.mtype in [0,1,2]:
@@ -277,7 +339,7 @@ class TLSARecord:
 
        def isNameValid(self):
                """Check if the name if in the correct format"""
-               if not re.match('^(_\d{1,5}|\*)\._(tcp|udp|sctp)\.([a-z0-9]*\.){2,}$', self.name):
+               if not re.match('^(_\d{1,5}|\*)\._(tcp|udp|sctp)\.([-a-z0-9]*\.){2,}$', self.name):
                        return False
                return True
 
@@ -313,6 +375,7 @@ class AAAARecord:
        """An object representing an AAAA Record (IPv6 address)"""
        def __init__(self, hostname, address):
                self.rrtype = 28
+               self.hostname = hostname
                self.address = address
 
        def __str__(self):
@@ -327,33 +390,30 @@ class AAAARecord:
 
 # Exceptions
 class RecordValidityException(Exception):
-       def __init__(self, value):
-               self.value = value
-       def __str__(self):
-               return self.value
+       pass
 
 class InsecureLookupException(Exception):
-       def __init__(self, value):
-               self.value = value
-       def __str__(self):
-               return self.value
+       pass
+
+class DNSLookupError(Exception):
+       pass
 
 if __name__ == '__main__':
        import argparse
        # create the parser
-       parser = argparse.ArgumentParser(description='Create and verify DANE records.', epilog='This tool has a few limitations: it only IPv4 for SSL connections.')
+       parser = argparse.ArgumentParser(description='Create and verify DANE records.', epilog='This tool has a few limitations')
 
        subparsers = parser.add_subparsers(title='Functions', help='Available functions, see %(prog)s function -h for function-specific help')
-       parser_verify = subparsers.add_parser('verify', help='Verify a TLSA record, exit 0 when all TLSA records are matched, exit 2 when a record does not match the received certificate, exit 1 on error.', epilog='Caveat: For TLSA validation, this program chases through the certificate chain offered by the server, not it\'s local certificates.')
+       parser_verify = subparsers.add_parser('verify', help='Verify a TLSA record, exit 0 when all TLSA records are matched, exit 2 when a record does not match the received certificate, exit 1 on error.', epilog='Caveat: For TLSA validation, this program chases through the certificate chain offered by the server, not its local certificates.')
        parser_verify.set_defaults(function='verify')
        parser_create = subparsers.add_parser('create', help='Create a TLSA record')
        parser_create.set_defaults(function='create')
 
-       #parser.add_argument('-4', dest='ipv4', action='store_true',help='use ipv4 networking only')
-       #parser.add_argument('-6', dest='ipv6', action='store_true',help='use ipv6 networking only')
+       parser.add_argument('-4', dest='ipv4', action='store_true',help='use ipv4 networking only')
+       parser.add_argument('-6', dest='ipv6', action='store_true',help='use ipv6 networking only')
        parser.add_argument('--insecure', action='store_true', default=False, help='Allow use of non-dnssec secured answers')
        parser.add_argument('--resolvconf', metavar='/PATH/TO/RESOLV.CONF', action='store', default='', help='Use a recursive resolver from resolv.conf')
-       parser.add_argument('-v', '--version', action='version', version='%(prog)s v0.1', help='show version and exit')
+       parser.add_argument('-v', '--version', action='version', version='%(prog)s v0.2', help='show version and exit')
        parser.add_argument('host', metavar="hostname")
 
        parser_verify.add_argument('--port', '-p', action='store', default='443', help='The port, or \'*\' where running TLS is located (default: %(default)s).')
@@ -365,14 +425,23 @@ if __name__ == '__main__':
        parser_create.add_argument('--port', '-p', action='store', type=int, default=443, help='The port where running TLS is located (default: %(default)s).')
        parser_create.add_argument('--protocol', action='store', choices=['tcp','udp','sctp'], default='tcp', help='The protocol the TLS service is using (default: %(default)s).')
        parser_create.add_argument('--certificate', '-c', help='The certificate used for the host. If certificate is empty, the certificate will be downloaded from the server')
-       parser_create.add_argument('--output', '-o', action='store', default='draft', choices=['draft','rfc','both'], help='The type of output. Draft (private RRtype, 65468), RFC (TLSA) or both (default: %(default)s).')
+       parser_create.add_argument('--output', '-o', action='store', default='generic', choices=['generic','rfc','both'], help='The type of output. Generic (RFC 3597, TYPE52), RFC (TLSA) or both (default: %(default)s).')
 
        # Usage of the certificate
-       parser_create.add_argument('--usage', '-u', action='store', type=int, default=1, choices=[0,1,2], help='The Usage of the Certificate for Association. \'0\' for CA, \'1\' for End Entity, \'2\' for trust-anchor (default: %(default)s).')
+       parser_create.add_argument('--usage', '-u', action='store', type=int, default=1, choices=[0,1,2,3], help='The Usage of the Certificate for Association. \'0\' for CA, \'1\' for End Entity, \'2\' for trust-anchor, \'3\' for ONLY End-Entity match (default: %(default)s).')
        parser_create.add_argument('--selector', '-s', action='store', type=int, default=0, choices=[0,1], help='The Selector for the Certificate for Association. \'0\' for Full Certificate, \'1\' for SubjectPublicKeyInfo (default: %(default)s).')
        parser_create.add_argument('--mtype', '-m', action='store', type=int, default=1, choices=[0,1,2], help='The Matching Type of the Certificate for Association. \'0\' for Exact match, \'1\' for SHA-256 hash, \'2\' for SHA-512 (default: %(default)s).')
 
        args = parser.parse_args()
+       import pprint
+       pprint.pprint(args)
+       if args.ipv4 == True and args.ipv6 == True: 
+               print "Cannot have only ipv4 and only ipv6 at the same time"
+               sys.exit()
+       elif args.ipv4 == True:
+               check_ipv6 = False
+       elif args.ipv6 == True:
+               check_ipv4 = False
 
        if args.host[-1] != '.':
                args.host += '.'
@@ -400,16 +469,16 @@ if __name__ == '__main__':
                        # First, check if the first three fields have correct values.
                        if not args.quiet:
                                print 'Received the following record for name %s:' % record.name
-                               print '\tUsage:\t\t\t\t%d (%s)' % (record.usage, {0:'CA Constraint', 1:'End-Entity Constraint', 2:'Trust Anchor'}[record.usage])
-                               print '\tSelector:\t\t\t%d (%s)' % (record.selector, {0:'Certificate', 1:'SubjectPublicKeyInfo'}[record.selector])
-                               print '\tMatching Type:\t\t\t%d (%s)' % (record.mtype, {0:'Full Certificate', 1:'SHA-256', 2:'SHA-512'}[record.mtype])
+                               print '\tUsage:\t\t\t\t%d (%s)' % (record.usage, {0:'CA Constraint', 1:'End-Entity Constraint + chain to CA', 2:'Trust Anchor', 3:'End-Entity'}.get(record.usage, 'INVALID'))
+                               print '\tSelector:\t\t\t%d (%s)' % (record.selector, {0:'Certificate', 1:'SubjectPublicKeyInfo'}.get(record.selector, 'INVALID'))
+                               print '\tMatching Type:\t\t\t%d (%s)' % (record.mtype, {0:'Full Certificate', 1:'SHA-256', 2:'SHA-512'}.get(record.mtype, 'INVALID'))
                                print '\tCertificate for Association:\t%s' % record.cert
 
                        try:
                                record.isValid(raiseException=True)
                        except RecordValidityException, e:
-                               print sys.stderr, 'Error: %s' % str(e)
-                               sys.exit(1)
+                               print >> sys.stderr, 'Error: %s' % str(e)
+                               continue
                        else:
                                if not args.quiet:
                                        print 'This record is valid (well-formed).'
@@ -425,7 +494,13 @@ if __name__ == '__main__':
 
                        if not args.quiet:
                                print 'Attempting to verify the record with the TLS service...'
-                       addresses = getA(args.host, secure=secure)
+                       if check_ipv4 and check_ipv6:
+                               addresses = getA(args.host, secure=secure) + getAAAA(args.host, secure=secure)
+                       elif check_ipv4:
+                               addresses = getA(args.host, secure=secure) 
+                       else:
+                               addresses = getAAAA(args.host, secure=secure)
+                               
                        for address in addresses:
                                if not args.quiet:
                                        print 'Got the following IP: %s' % str(address)
@@ -440,21 +515,36 @@ if __name__ == '__main__':
                                        sys.exit(1)
                                # Don't error when the verification fails in the SSL handshake
                                ctx.set_verify(SSL.verify_none, depth=9)
-                               connection = SSL.Connection(ctx)
+                               if  check_ipv6 and isinstance(address, AAAARecord):
+                                       sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+                                       sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                               elif  check_ipv4 and isinstance(address, ARecord):
+                                       sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                                       sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                               else:
+                                       sock = None
+                               connection = SSL.Connection(ctx, sock=sock)
                                try:
                                        connection.connect((str(address), int(args.port)))
                                except SSL.Checker.WrongHost, e:
                                        # The name on the remote cert doesn't match the hostname because we connect on IP, not hostname (as we want secure lookup)
                                        pass
+                               except socket.error, e:
+                                       print 'Cannot connect to %s: %s' % (address, str(e))
+                                       continue
                                chain = connection.get_peer_cert_chain()
                                verify_result = connection.get_verify_result()
 
                                # Good, now let's verify
+                               if not verifyCertNameWithHostName(cert=chain[0], hostname=str(args.host), with_msg=True):
+                                       # The name on the cert doesn't match the hostname... we don't verify the TLSA record
+                                       print 'Not checking the TLSA record.'
+                                       continue
                                if record.usage == 1: # End-host cert
                                        cert = chain[0]
                                        if verifyCertMatch(record, cert):
                                                if verify_result == 0: # The cert chains to a valid CA cert according to the system-certificates
-                                                       print 'SUCCES (Usage 1): Certificate offered by the server matches the one mentioned in the TLSA record and chains to a valid CA certificate'
+                                                       print 'SUCCESS (Usage 1): Certificate offered by the server matches the one mentioned in the TLSA record and chains to a valid CA certificate'
                                                else:
                                                        print 'FAIL (Usage 1): Certificate offered by the server matches the one mentioned in the TLSA record but the following error was raised during PKIX validation: %s' % getVerificationErrorReason(verify_result)
                                                        if pre_exit == 0: pre_exit = 2
@@ -466,6 +556,7 @@ if __name__ == '__main__':
                                elif record.usage == 0: # CA constraint
                                        matched = False
                                        # Remove the first (= End-Entity cert) from the chain
+                                       chain = chain[1:]
                                        for cert in chain:
                                                if verifyCertMatch(record, cert):
                                                        matched = True
@@ -473,7 +564,7 @@ if __name__ == '__main__':
                                        if matched:
                                                if cert.check_ca():
                                                        if verify_result == 0:
-                                                               print 'SUCCES (Usage 0): A certificate in the certificate chain offered by the server matches the one mentioned in the TLSA record and is a CA certificate'
+                                                               print 'SUCCESS (Usage 0): A certificate in the certificate chain offered by the server matches the one mentioned in the TLSA record and is a CA certificate'
                                                        else:
                                                                print 'FAIL (Usage 0): A certificate in the certificate chain offered by the server matches the one mentioned in the TLSA record and is a CA certificate, but the following error was raised during PKIX validation:' % getVerificationErrorReason(verify_result)
                                                                if pre_exit == 0: pre_exit = 2
@@ -485,19 +576,34 @@ if __name__ == '__main__':
                                                print 'FAIL (Usage 0): No certificate in the certificate chain offered by the server matches the TLSA record'
                                                if pre_exit == 0: pre_exit = 2
 
-                               elif record.usage == 2: # Usage 2, ANY cert in the chain must match (aka 'pick any')
+                               elif record.usage == 2: # Usage 2, use the cert in the record as trust anchor
+                                       #FIXME: doesnt comply to the spec
                                        matched = False
+                                       previous_issuer = None
                                        for cert in chain:
+                                               if previous_issuer:
+                                                       if not str(previous_issuer) == str(cert.get_subject()): # The chain cannot be valid
+                                                               print "FAIL: Certificates don't chain"
+                                                               break
+                                                       previous_issuer = cert.get_issuer()
                                                if verifyCertMatch(record, cert):
                                                        matched = True
                                                        continue
                                        if matched:
-                                               print 'SUCCES (usage 2): A certificate in the certificate chain (including the end-entity certificate) offered by the server matches the TLSA record'
+                                               print 'SUCCESS (usage 2): A certificate in the certificate chain (including the end-entity certificate) offered by the server matches the TLSA record'
                                                if not args.quiet: print 'The matched certificate has Subject: %s' % cert.get_subject()
                                        else:
                                                print 'FAIL (usage 2): No certificate in the certificate chain (including the end-entity certificate) offered by the server matches the TLSA record'
                                                if pre_exit == 0: pre_exit = 2
 
+                               elif record.usage == 3: # EE cert MUST match
+                                       if verifyCertMatch(record,chain[0]):
+                                               print 'SUCCESS (usage 3): The certificate offered by the server matches the TLSA record'
+                                               if not args.quiet: print 'The matched certificate has Subject: %s' % chain[0].get_subject()
+                                       else:
+                                               print 'FAIL (usage 3): The certificate offered by the server does not match the TLSA record'
+                                               if pre_exit == 0: pre_exit = 2
+
                                # Cleanup, just in case
                                connection.clear()
                                connection.close()
@@ -531,9 +637,15 @@ if __name__ == '__main__':
                                                        input_ok = True
                                        except:
                                                sys.stdout.write('Port %s not numerical or within correct range (1 <= port <= 65535), try again (hit enter for default 443): ' % user_input)
-                       # Get the A records for the host
+                       # Get the address records for the host
                        try:
-                               addresses = getA(args.host, secure=secure)
+                               if check_ipv4 and check_ipv6:
+                                       addresses = getA(args.host, secure=secure) + getAAAA(args.host, secure=secure)
+                               elif check_ipv4:
+                                       addresses = getA(args.host, secure=secure) 
+                               else:
+                                       addresses = getAAAA(args.host, secure=secure)
+
                        except InsecureLookupException, e:
                                print >> sys.stderr, str(e)
                                sys.exit(1)
@@ -543,22 +655,33 @@ if __name__ == '__main__':
                                # We do the certificate handling here, as M2Crypto keeps segfaulting when try to do stuff with the cert if we don't
                                ctx = SSL.Context()
                                ctx.set_verify(SSL.verify_none, depth=9)
-                               connection = SSL.Connection(ctx)
+                               if check_ipv6 and isinstance(address, AAAARecord):
+                                       sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+                                       sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                               if check_ipv4 and isinstance(address, ARecord):
+                                       sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                                       sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                               else:
+                                       sock = None
+                               connection = SSL.Connection(ctx, sock=sock)
                                try:
                                        connection.connect((str(address), int(connection_port)))
                                except SSL.Checker.WrongHost:
                                        pass
+                               except socket.error, e:
+                                       print 'Cannot connect to %s: %s' % (address, str(e))
+                                       continue
 
                                chain = connection.get_peer_cert_chain()
                                for chaincert in chain:
-                                       if int(args.usage) == 1:
+                                       if int(args.usage) == 1 or int(args.usage) == 3:
                                                # The first cert is the end-entity cert
                                                print 'Got a certificate with Subject: %s' % chaincert.get_subject()
                                                cert = chaincert
                                                break
                                        else:
                                                if (int(args.usage) == 0 and chaincert.check_ca()) or int(args.usage) == 2:
-                                                       sys.stdout.write('Got a certificate with the following Subject:\n\t%s.\nUse this as certificate to match? [y/N] ' % chaincert.get_subject())
+                                                       sys.stdout.write('Got a certificate with the following Subject:\n\t%s\nUse this as certificate to match? [y/N] ' % chaincert.get_subject())
                                                        input_ok = False
                                                        while not input_ok:
                                                                user_input = raw_input()
@@ -573,14 +696,21 @@ if __name__ == '__main__':
                                                        break
 
                                if cert: # Print the requested records based on the retrieved certificates
-                                       if args.output == 'b':
+                                       if args.output == 'both':
                                                print genTLSA(args.host, args.protocol, args.port, cert, 'draft', args.usage, args.selector, args.mtype)
                                                print genTLSA(args.host, args.protocol, args.port, cert, 'rfc', args.usage, args.selector, args.mtype)
                                        else:
                                                print genTLSA(args.host, args.protocol, args.port, cert, args.output, args.usage, args.selector, args.mtype)
 
+                               # Clear the cert from memory (to stop M2Crypto from segfaulting)
+                               # And cleanup the connection and context
+                               cert=None
+                               connection.clear()
+                               connection.close()
+                               ctx.close()
+
                else: # Pass the path to the certificate to the genTLSA function
-                       if args.output == 'b':
+                       if args.output == 'both':
                                print genTLSA(args.host, args.protocol, args.port, args.certificate, 'draft', args.usage, args.selector, args.mtype)
                                print genTLSA(args.host, args.protocol, args.port, args.certificate, 'rfc', args.usage, args.selector, args.mtype)
                        else: