From ee03375f5e34a992f317dcbbd39c5c6bbad3ca0e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavel=20K=C3=A1cha?= <ph@cesnet.cz>
Date: Tue, 8 Dec 2015 16:54:44 +0100
Subject: [PATCH] Warden client now uses system CA stores if not explicitly
 specified

---
 warden3/warden_client/warden_client.py | 28 +++++++++++++++++++++++---
 1 file changed, 25 insertions(+), 3 deletions(-)

diff --git a/warden3/warden_client/warden_client.py b/warden3/warden_client/warden_client.py
index 6861d10..c4a4e23 100644
--- a/warden3/warden_client/warden_client.py
+++ b/warden3/warden_client/warden_client.py
@@ -19,6 +19,12 @@ fix_logging_filename = str if version_info<(2, 7) else lambda(x): x
 
 VERSION = "3.0-beta2"
 
+DEFAULT_CA_STORES = [
+    "/etc/ssl/certs/ca-certificates.crt",       # Deb
+    "/etc/pki/tls/certs/ca-bundle.crt",         # RH
+    "/var/lib/ca-certificates/ca-bundle.pem"    # SuSE
+ ]
+
 class HTTPSConnection(httplib.HTTPSConnection):
     '''
     Overridden to allow peer certificate validation, configuration
@@ -233,9 +239,9 @@ class Client(object):
         self.conn = None
 
         base = path.join(path.dirname(__file__))
-        self.certfile = path.join(base, certfile or "cert.pem")
-        self.keyfile  = path.join(base, keyfile or "key.pem")
-        self.cafile = path.join(base, cafile or "ca.pem")
+        self.certfile = self.get_readable_file(certfile or "cert.pem", base)
+        self.keyfile  = self.get_readable_file(keyfile or "key.pem", base)
+        self.cafile = self.get_readable_file(cafile if cafile is not None else DEFAULT_CA_STORES, base)
         self.timeout = int(timeout)
         self.get_events_limit = int(get_events_limit)
         self.idstore = path.join(base, idstore) if idstore is not None else None
@@ -250,6 +256,22 @@ class Client(object):
         self.getInfo()  # Call to align limits with server opinion
 
 
+    def get_readable_file(self, name, base):
+        names = [name] if isinstance(name, basestring) else name
+        names = [path.join(base, n) for n in names]
+        errors = []
+        for n in names:
+            try:
+                open(n, "r").close()
+                self.logger.debug("Using %s" % n)
+                return n
+            except IOError as e:
+                errors.append(e)
+        for e in errors:
+            self.logger.error(str(e))
+        return names[0] if names else None
+
+
     def init_log(self, errlog, syslog, filelog):
 
         def loglevel(lev):
-- 
GitLab