Fix a segfault when creating records
[public/dnssec-swede-utility.git] / swede
diff --git a/swede b/swede
index cb44cd31b00c9ee5ebb6cded752271645e9a1923..10325c3fd9649491b22dce978f79e121e5da329c 100755 (executable)
--- a/swede
+++ b/swede
@@ -1,6 +1,6 @@
 #!/usr/bin/python
 
-# swede - A tool to create DANE/TLSA (draft 14) records.
+# swede - A tool to create DANE/TLSA (draft 15) records.
 # This tool is really simple and not foolproof, it doesn't check the CN in the
 # Subject field of the certificate. It also doesn't check if the supplied
 # certificate is a CA certificate if usage 1 is specified (or any other
@@ -127,6 +127,7 @@ def getRecords(hostname, rrtype='A', secure=True):
        global resolvconf
        ctx = unbound.ub_ctx()
        ctx.add_ta_file('root.key')
+       ctx.set_option("dlv-anchor-file:", "dlv.isc.org.key")
        # Use the local cache
        if resolvconf and os.path.isfile(resolvconf):
                ctx.resolvconf(resolvconf)
@@ -270,8 +271,8 @@ class TLSARecord:
                except:
                        if self.getPort() != '*':
                                err.append('Port %s not a number' % self.getPort())
-               if not self.usage in [0,1,2]:
-                       err.append('Usage: invalid (%s is not one of 0, 1 or 2)' % self.usage)
+               if not self.usage in [0,1,2,3]:
+                       err.append('Usage: invalid (%s is not one of 0, 1, 2 or 3)' % self.usage)
                if not self.selector in [0,1]:
                        err.append('Selector: invalid (%s is not one of 0 or 1)' % self.selector)
                if not self.mtype in [0,1,2]:
@@ -368,7 +369,7 @@ if __name__ == '__main__':
        #parser.add_argument('-6', dest='ipv6', action='store_true',help='use ipv6 networking only')
        parser.add_argument('--insecure', action='store_true', default=False, help='Allow use of non-dnssec secured answers')
        parser.add_argument('--resolvconf', metavar='/PATH/TO/RESOLV.CONF', action='store', default='', help='Use a recursive resolver from resolv.conf')
-       parser.add_argument('-v', '--version', action='version', version='%(prog)s v0.1', help='show version and exit')
+       parser.add_argument('-v', '--version', action='version', version='%(prog)s v0.2', help='show version and exit')
        parser.add_argument('host', metavar="hostname")
 
        parser_verify.add_argument('--port', '-p', action='store', default='443', help='The port, or \'*\' where running TLS is located (default: %(default)s).')
@@ -383,7 +384,7 @@ if __name__ == '__main__':
        parser_create.add_argument('--output', '-o', action='store', default='draft', choices=['draft','rfc','both'], help='The type of output. Draft (private RRtype, 65468), RFC (TLSA) or both (default: %(default)s).')
 
        # Usage of the certificate
-       parser_create.add_argument('--usage', '-u', action='store', type=int, default=1, choices=[0,1,2], help='The Usage of the Certificate for Association. \'0\' for CA, \'1\' for End Entity, \'2\' for trust-anchor (default: %(default)s).')
+       parser_create.add_argument('--usage', '-u', action='store', type=int, default=1, choices=[0,1,2,3], help='The Usage of the Certificate for Association. \'0\' for CA, \'1\' for End Entity, \'2\' for trust-anchor, \'3\' for ONLY End-Entity match (default: %(default)s).')
        parser_create.add_argument('--selector', '-s', action='store', type=int, default=0, choices=[0,1], help='The Selector for the Certificate for Association. \'0\' for Full Certificate, \'1\' for SubjectPublicKeyInfo (default: %(default)s).')
        parser_create.add_argument('--mtype', '-m', action='store', type=int, default=1, choices=[0,1,2], help='The Matching Type of the Certificate for Association. \'0\' for Exact match, \'1\' for SHA-256 hash, \'2\' for SHA-512 (default: %(default)s).')
 
@@ -415,9 +416,9 @@ if __name__ == '__main__':
                        # First, check if the first three fields have correct values.
                        if not args.quiet:
                                print 'Received the following record for name %s:' % record.name
-                               print '\tUsage:\t\t\t\t%d (%s)' % (record.usage, {0:'CA Constraint', 1:'End-Entity Constraint', 2:'Trust Anchor'}[record.usage])
-                               print '\tSelector:\t\t\t%d (%s)' % (record.selector, {0:'Certificate', 1:'SubjectPublicKeyInfo'}[record.selector])
-                               print '\tMatching Type:\t\t\t%d (%s)' % (record.mtype, {0:'Full Certificate', 1:'SHA-256', 2:'SHA-512'}[record.mtype])
+                               print '\tUsage:\t\t\t\t%d (%s)' % (record.usage, {0:'CA Constraint', 1:'End-Entity Constraint + chain to CA', 2:'Trust Anchor', 3:'End-Entity'}.get(record.usage, 'INVALID'))
+                               print '\tSelector:\t\t\t%d (%s)' % (record.selector, {0:'Certificate', 1:'SubjectPublicKeyInfo'}.get(record.selector, 'INVALID'))
+                               print '\tMatching Type:\t\t\t%d (%s)' % (record.mtype, {0:'Full Certificate', 1:'SHA-256', 2:'SHA-512'}.get(record.mtype, 'INVALID'))
                                print '\tCertificate for Association:\t%s' % record.cert
 
                        try:
@@ -481,6 +482,7 @@ if __name__ == '__main__':
                                elif record.usage == 0: # CA constraint
                                        matched = False
                                        # Remove the first (= End-Entity cert) from the chain
+                                       chain = chain[1:]
                                        for cert in chain:
                                                if verifyCertMatch(record, cert):
                                                        matched = True
@@ -500,9 +502,16 @@ if __name__ == '__main__':
                                                print 'FAIL (Usage 0): No certificate in the certificate chain offered by the server matches the TLSA record'
                                                if pre_exit == 0: pre_exit = 2
 
-                               elif record.usage == 2: # Usage 2, ANY cert in the chain must match (aka 'pick any')
+                               elif record.usage == 2: # Usage 2, use the cert in the record as trust anchor
+                                       #FIXME: doesnt comply to the spec
                                        matched = False
+                                       previous_issuer = None
                                        for cert in chain:
+                                               if previous_issuer:
+                                                       if not str(previous_issuer) == str(cert.get_subject()): # The chain cannot be valid
+                                                               print "FAIL: Certificates don't chain"
+                                                               break
+                                                       previous_issuer = cert.get_issuer()
                                                if verifyCertMatch(record, cert):
                                                        matched = True
                                                        continue
@@ -513,6 +522,14 @@ if __name__ == '__main__':
                                                print 'FAIL (usage 2): No certificate in the certificate chain (including the end-entity certificate) offered by the server matches the TLSA record'
                                                if pre_exit == 0: pre_exit = 2
 
+                               elif record.usage == 3: # EE cert MUST match
+                                       if verifyCertMatch(record,chain[0]):
+                                               print 'SUCCESS (usage 3): The certificate offered by the server matches the TLSA record'
+                                               if not args.quiet: print 'The matched certificate has Subject: %s' % chain[0].get_subject()
+                                       else:
+                                               print 'FAIL (usage 3): The certificate offered by the server does not match the TLSA record'
+                                               if pre_exit == 0: pre_exit = 2
+
                                # Cleanup, just in case
                                connection.clear()
                                connection.close()
@@ -566,14 +583,14 @@ if __name__ == '__main__':
 
                                chain = connection.get_peer_cert_chain()
                                for chaincert in chain:
-                                       if int(args.usage) == 1:
+                                       if int(args.usage) == 1 or int(args.usage) == 3:
                                                # The first cert is the end-entity cert
                                                print 'Got a certificate with Subject: %s' % chaincert.get_subject()
                                                cert = chaincert
                                                break
                                        else:
                                                if (int(args.usage) == 0 and chaincert.check_ca()) or int(args.usage) == 2:
-                                                       sys.stdout.write('Got a certificate with the following Subject:\n\t%s.\nUse this as certificate to match? [y/N] ' % chaincert.get_subject())
+                                                       sys.stdout.write('Got a certificate with the following Subject:\n\t%s\nUse this as certificate to match? [y/N] ' % chaincert.get_subject())
                                                        input_ok = False
                                                        while not input_ok:
                                                                user_input = raw_input()
@@ -594,6 +611,13 @@ if __name__ == '__main__':
                                        else:
                                                print genTLSA(args.host, args.protocol, args.port, cert, args.output, args.usage, args.selector, args.mtype)
 
+                               # Clear the cert from memory (to stop M2Crypto from segfaulting)
+                               # And cleanup the connection and context
+                               cert=None
+                               connection.clear()
+                               connection.close()
+                               ctx.close()
+
                else: # Pass the path to the certificate to the genTLSA function
                        if args.output == 'both':
                                print genTLSA(args.host, args.protocol, args.port, args.certificate, 'draft', args.usage, args.selector, args.mtype)