When printing to stderr, redirect correctly
[public/dnssec-swede-utility.git] / swede
diff --git a/swede b/swede
index 643c079d55107dc97c3792040a6d6589e6971c87..8e329127dbce1e3a25356f1d7cd492ebd56d1ec1 100755 (executable)
--- a/swede
+++ b/swede
@@ -24,6 +24,7 @@ from binascii import a2b_hex, b2a_hex
 from hashlib import sha256, sha512
 from ipaddr import IPv4Address, IPv6Address
 
+
 def genTLSA(hostname, protocol, port, certificate, output='draft', usage=1, selector=0, mtype=1):
        """This function generates a TLSARecord object using the data passed in the parameters,
        it then validates the record and returns the RR as a string.
@@ -58,7 +59,14 @@ def genTLSA(hostname, protocol, port, certificate, output='draft', usage=1, sele
 
 def getA(hostname, secure=True):
        """Gets a list of A records for hostname, returns a list of ARecords"""
-       records = getRecords(hostname, rrtype='A', secure=secure)
+       try:
+               records = getRecords(hostname, rrtype='A', secure=secure)
+       except InsecureLookupException, e:
+               print str(e)
+               sys.exit(1)
+       except DNSLookupError, e:
+               print 'Unable to resolve %s: %s' % (hostname, str(e))
+               sys.exit(1)
        ret = []
        for record in records:
                ret.append(ARecord(hostname, str(IPv4Address(int(b2a_hex(record),16)))))
@@ -66,7 +74,14 @@ def getA(hostname, secure=True):
 
 def getAAAA(hostname, secure=True):
        """Gets a list of A records for hostname, returns a list of AAAARecords"""
-       records = getRecords(hostname, rrtype='AAAA', secure=secure)
+       try:
+               records = getRecords(hostname, rrtype='AAAA', secure=secure)
+       except InsecureLookupException, e:
+               print str(e)
+               sys.exit(1)
+       except DNSLookupError, e:
+               print 'Unable to resolve %s: %s' % (hostname, str(e))
+               sys.exit(1)
        ret = []
        for record in records:
                ret.append(AAAARecord(hostname, str(IPv6Address(int(b2a_hex(record),16)))))
@@ -109,8 +124,12 @@ def getVerificationErrorReason(num):
 
 def getRecords(hostname, rrtype='A', secure=True):
        """Do a lookup of a name and a rrtype, returns a list of binary coded strings. Only queries for rr_class IN."""
+       global resolvconf
        ctx = unbound.ub_ctx()
        ctx.add_ta_file('root.key')
+       # Use the local cache
+       if resolvconf and os.path.isfile(resolvconf):
+               ctx.resolvconf(resolvconf)
 
        if type(rrtype) == str:
                if 'RR_TYPE_' + rrtype in dir(unbound):
@@ -131,7 +150,7 @@ def getRecords(hostname, rrtype='A', secure=True):
                # If we are here the data was either secure or insecure data is accepted
                return result.data.raw
        else:
-               raise Exception('Error: Unsuccesful lookup or no data returned.')
+               raise DNSLookupError('Unsuccesful lookup or no data returned for rrtype %s.' % rrtype)
 
 def getHash(certificate, mtype):
        """Hashes the certificate based on the mtype.
@@ -166,6 +185,9 @@ def getTLSA(hostname, port=443, protocol='tcp', secure=True):
        except InsecureLookupException, e:
                print str(e)
                sys.exit(1)
+       except DNSLookupError, e:
+               print 'Unable to resolve %s: %s' % (hostname, str(e))
+               sys.exit(1)
        ret = []
        for record in records:
                hexdata = b2a_hex(record)
@@ -273,7 +295,7 @@ class TLSARecord:
 
        def isNameValid(self):
                """Check if the name if in the correct format"""
-               if not re.match('^(_\d{1,5}|\*)\._(tcp|udp|sctp)\.([a-z0-9]*\.){2,}$', self.name):
+               if not re.match('^(_\d{1,5}|\*)\._(tcp|udp|sctp)\.([-a-z0-9]*\.){2,}$', self.name):
                        return False
                return True
 
@@ -323,16 +345,13 @@ class AAAARecord:
 
 # Exceptions
 class RecordValidityException(Exception):
-       def __init__(self, value):
-               self.value = value
-       def __str__(self):
-               return self.value
+       pass
 
 class InsecureLookupException(Exception):
-       def __init__(self, value):
-               self.value = value
-       def __str__(self):
-               return self.value
+       pass
+
+class DNSLookupError(Exception):
+       pass
 
 if __name__ == '__main__':
        import argparse
@@ -348,6 +367,7 @@ if __name__ == '__main__':
        #parser.add_argument('-4', dest='ipv4', action='store_true',help='use ipv4 networking only')
        #parser.add_argument('-6', dest='ipv6', action='store_true',help='use ipv6 networking only')
        parser.add_argument('--insecure', action='store_true', default=False, help='Allow use of non-dnssec secured answers')
+       parser.add_argument('--resolvconf', metavar='/PATH/TO/RESOLV.CONF', action='store', default='', help='Use a recursive resolver from resolv.conf')
        parser.add_argument('-v', '--version', action='version', version='%(prog)s v0.1', help='show version and exit')
        parser.add_argument('host', metavar="hostname")
 
@@ -372,6 +392,16 @@ if __name__ == '__main__':
        if args.host[-1] != '.':
                args.host += '.'
 
+       global resolvconf
+       if args.resolvconf:
+               if os.path.isfile(args.resolvconf):
+                       resolvconf = args.resolvconf
+               else:
+                       print >> sys.stdout, '%s is not a file. Unable to use it as resolv.conf' % args.resolvconf
+                       sys.exit(1)
+       else:
+               resolvconf = None
+
        # not operations are fun!
        secure = not args.insecure
 
@@ -393,8 +423,8 @@ if __name__ == '__main__':
                        try:
                                record.isValid(raiseException=True)
                        except RecordValidityException, e:
-                               print sys.stderr, 'Error: %s' % str(e)
-                               sys.exit(1)
+                               print >> sys.stderr, 'Error: %s' % str(e)
+                               continue
                        else:
                                if not args.quiet:
                                        print 'This record is valid (well-formed).'
@@ -558,14 +588,14 @@ if __name__ == '__main__':
                                                        break
 
                                if cert: # Print the requested records based on the retrieved certificates
-                                       if args.output == 'b':
+                                       if args.output == 'both':
                                                print genTLSA(args.host, args.protocol, args.port, cert, 'draft', args.usage, args.selector, args.mtype)
                                                print genTLSA(args.host, args.protocol, args.port, cert, 'rfc', args.usage, args.selector, args.mtype)
                                        else:
                                                print genTLSA(args.host, args.protocol, args.port, cert, args.output, args.usage, args.selector, args.mtype)
 
                else: # Pass the path to the certificate to the genTLSA function
-                       if args.output == 'b':
+                       if args.output == 'both':
                                print genTLSA(args.host, args.protocol, args.port, args.certificate, 'draft', args.usage, args.selector, args.mtype)
                                print genTLSA(args.host, args.protocol, args.port, args.certificate, 'rfc', args.usage, args.selector, args.mtype)
                        else: