When printing to stderr, redirect correctly
[public/dnssec-swede-utility.git] / swede
diff --git a/swede b/swede
index 4be39b51e1c7d060ffe765f8e3def42d785ece0e..8e329127dbce1e3a25356f1d7cd492ebd56d1ec1 100755 (executable)
--- a/swede
+++ b/swede
@@ -24,6 +24,7 @@ from binascii import a2b_hex, b2a_hex
 from hashlib import sha256, sha512
 from ipaddr import IPv4Address, IPv6Address
 
+
 def genTLSA(hostname, protocol, port, certificate, output='draft', usage=1, selector=0, mtype=1):
        """This function generates a TLSARecord object using the data passed in the parameters,
        it then validates the record and returns the RR as a string.
@@ -58,7 +59,14 @@ def genTLSA(hostname, protocol, port, certificate, output='draft', usage=1, sele
 
 def getA(hostname, secure=True):
        """Gets a list of A records for hostname, returns a list of ARecords"""
-       records = getRecords(hostname, rrtype='A', secure=secure)
+       try:
+               records = getRecords(hostname, rrtype='A', secure=secure)
+       except InsecureLookupException, e:
+               print str(e)
+               sys.exit(1)
+       except DNSLookupError, e:
+               print 'Unable to resolve %s: %s' % (hostname, str(e))
+               sys.exit(1)
        ret = []
        for record in records:
                ret.append(ARecord(hostname, str(IPv4Address(int(b2a_hex(record),16)))))
@@ -66,7 +74,14 @@ def getA(hostname, secure=True):
 
 def getAAAA(hostname, secure=True):
        """Gets a list of A records for hostname, returns a list of AAAARecords"""
-       records = getRecords(hostname, rrtype='AAAA', secure=secure)
+       try:
+               records = getRecords(hostname, rrtype='AAAA', secure=secure)
+       except InsecureLookupException, e:
+               print str(e)
+               sys.exit(1)
+       except DNSLookupError, e:
+               print 'Unable to resolve %s: %s' % (hostname, str(e))
+               sys.exit(1)
        ret = []
        for record in records:
                ret.append(AAAARecord(hostname, str(IPv6Address(int(b2a_hex(record),16)))))
@@ -135,7 +150,7 @@ def getRecords(hostname, rrtype='A', secure=True):
                # If we are here the data was either secure or insecure data is accepted
                return result.data.raw
        else:
-               raise Exception('Error: Unsuccesful lookup or no data returned.')
+               raise DNSLookupError('Unsuccesful lookup or no data returned for rrtype %s.' % rrtype)
 
 def getHash(certificate, mtype):
        """Hashes the certificate based on the mtype.
@@ -170,6 +185,9 @@ def getTLSA(hostname, port=443, protocol='tcp', secure=True):
        except InsecureLookupException, e:
                print str(e)
                sys.exit(1)
+       except DNSLookupError, e:
+               print 'Unable to resolve %s: %s' % (hostname, str(e))
+               sys.exit(1)
        ret = []
        for record in records:
                hexdata = b2a_hex(record)
@@ -277,7 +295,7 @@ class TLSARecord:
 
        def isNameValid(self):
                """Check if the name if in the correct format"""
-               if not re.match('^(_\d{1,5}|\*)\._(tcp|udp|sctp)\.([a-z0-9]*\.){2,}$', self.name):
+               if not re.match('^(_\d{1,5}|\*)\._(tcp|udp|sctp)\.([-a-z0-9]*\.){2,}$', self.name):
                        return False
                return True
 
@@ -327,16 +345,13 @@ class AAAARecord:
 
 # Exceptions
 class RecordValidityException(Exception):
-       def __init__(self, value):
-               self.value = value
-       def __str__(self):
-               return self.value
+       pass
 
 class InsecureLookupException(Exception):
-       def __init__(self, value):
-               self.value = value
-       def __str__(self):
-               return self.value
+       pass
+
+class DNSLookupError(Exception):
+       pass
 
 if __name__ == '__main__':
        import argparse
@@ -408,8 +423,8 @@ if __name__ == '__main__':
                        try:
                                record.isValid(raiseException=True)
                        except RecordValidityException, e:
-                               print sys.stderr, 'Error: %s' % str(e)
-                               sys.exit(1)
+                               print >> sys.stderr, 'Error: %s' % str(e)
+                               continue
                        else:
                                if not args.quiet:
                                        print 'This record is valid (well-formed).'
@@ -573,14 +588,14 @@ if __name__ == '__main__':
                                                        break
 
                                if cert: # Print the requested records based on the retrieved certificates
-                                       if args.output == 'b':
+                                       if args.output == 'both':
                                                print genTLSA(args.host, args.protocol, args.port, cert, 'draft', args.usage, args.selector, args.mtype)
                                                print genTLSA(args.host, args.protocol, args.port, cert, 'rfc', args.usage, args.selector, args.mtype)
                                        else:
                                                print genTLSA(args.host, args.protocol, args.port, cert, args.output, args.usage, args.selector, args.mtype)
 
                else: # Pass the path to the certificate to the genTLSA function
-                       if args.output == 'b':
+                       if args.output == 'both':
                                print genTLSA(args.host, args.protocol, args.port, args.certificate, 'draft', args.usage, args.selector, args.mtype)
                                print genTLSA(args.host, args.protocol, args.port, args.certificate, 'rfc', args.usage, args.selector, args.mtype)
                        else: