diff --git a/warden3/warden_client/warden_client.py b/warden3/warden_client/warden_client.py index 80f5b4ac3c9052a2787fbd3cf9a52dcce75e42d8..297d2fad3b0c94a28f0b3c5413d2d59edc5eff09 100644 --- a/warden3/warden_client/warden_client.py +++ b/warden3/warden_client/warden_client.py @@ -254,8 +254,9 @@ class Client(object): def sendRequest(self, func="", payload=None, **kwargs): - kwargs["client"] = self.name - if self.secret is not None: + if self.secret is None: + kwargs["client"] = self.name + else: kwargs["secret"] = self.secret if kwargs: diff --git a/warden3/warden_client/warden_curl_test.sh b/warden3/warden_client/warden_curl_test.sh index e513ce3146d4b18c851d87728711b4b17911c78e..e77c71c7baadb8d522be6cfb072b944781f12309 100755 --- a/warden3/warden_client/warden_curl_test.sh +++ b/warden3/warden_client/warden_curl_test.sh @@ -44,7 +44,7 @@ curl \ "$url/getEvents?client=$client" echo -echo "Test 403 - no client" +echo "Test 403 - no client, no secret" curl \ --key $keyfile \ --cert $certfile \ @@ -64,6 +64,36 @@ curl \ "$url/getEvents?client=asdf.blefub" echo +echo "Test 403 - wrong client, right secret" +curl \ + --key $keyfile \ + --cert $certfile \ + --cacert $cafile \ + --connect-timeout 3 \ + --request POST \ + "$url/getEvents?client=asdf.blefub&secret=$secret" +echo + +echo "Test 403 - right client, wrong secret" +curl \ + --key $keyfile \ + --cert $certfile \ + --cacert $cafile \ + --connect-timeout 3 \ + --request POST \ + "$url/getEvents?client=$client&secret=ASDFblefub" +echo + +echo "Test - no client, but secret, should be ok" +curl \ + --key $keyfile \ + --cert $certfile \ + --cacert $cafile \ + --connect-timeout 3 \ + --request POST \ + "$url/getEvents?secret=$secret" +echo + echo "Test Deserialization" curl \ --key $keyfile \ diff --git a/warden3/warden_server/warden_server.py b/warden3/warden_server/warden_server.py index 404d3675f268fa0c780ca139c1de340be6782eef..9c34759a69faeb0d99e020e91cb9b29d27e6823e 100755 --- a/warden3/warden_server/warden_server.py +++ b/warden3/warden_server/warden_server.py @@ -286,37 +286,28 @@ class X509Authenticator(NoAuthenticator): def authenticate (self, env, args): - try: - identity = args["client"][0] - except KeyError: - logging.info("authenticate: bad or missing client argument") - return None - try: cert_names = self.get_cert_dns_names(env["SSL_CLIENT_CERT"]) except: logging.info("authenticate: cannot get or parse certificate from env") return None - - client = self.db.get_client_by_name(identity, cert_names) + + identity = args.get("client", [None])[0] + secret = args.get("secret", [None])[0] + args["secret"] = ["..."] # Prevent to spill it over logs + + client = self.db.get_client_by_name(cert_names, identity, secret) if not client: - logging.info("authenticate: client not found") + logging.info("authenticate: client not found by identity: \"%s\", secret: %s, cert_names: %s" % ( + identity, "..." if secret else "None", str(cert_names))) return None # Clients with 'secret' set muset get authorized by it. # No secret turns auth off for this particular client. - if client.secret is not None: - try: - secret = args["secret"][0] - except KeyError: - logging.info("authenticate: missing secret argument") - return None - if secret != client.secret: - logging.info("authenticate: wrong credentials") - return None - # Already checked, prevent to spill it over logs - args["secret"] = ["..."] + if client.secret is not None and secret is None: + logging.info("authenticate: missing secret argument") + return None logging.info("authenticate: %s" % str(client)) @@ -421,19 +412,26 @@ class MySQL(ObjectReq): type(self).__name__, type(self.req).__name__, self.host, self.user, self.dbname, self.port, self.catmap_filename, self.tagmap_filename) - def get_client_by_name(self, identity, cert_names): - format_strings = ','.join(['%s'] * len(cert_names)) - query = "SELECT id, registered, requestor, hostname, service, note, identity, secret, `read`, debug, `write`, test FROM clients WHERE valid = 1 AND identity = %%s AND hostname IN (%s)" % format_strings - self.crs.execute(query, [identity] + cert_names) + def get_client_by_name(self, cert_names, identity=None, secret=None): + query = ["SELECT id, registered, requestor, hostname, service, note, identity, secret, `read`, debug, `write`, test FROM clients WHERE valid = 1"] + params = [] + if identity: + query.append(" AND identity = %s") + params.append(identity) + if secret: + query.append(" AND secret = %s") + params.append(secret) + query.append(" AND hostname IN (%s)" % ','.join(['%s'] * len(cert_names))) + params.extend(cert_names) + self.crs.execute("".join(query), params) rows = self.crs.fetchall() if len(rows)>1: - logging.warn("get_client_by_name: query returned more than one result: %s" % str(rows)) + logging.warn("get_client_by_name: query returned more than one result: %s" % ", ".join( + [str(Client(**row)) for row in rows])) return None - client = Client(**rows[0]) if rows else None - - return client + return Client(**rows[0]) if rows else None def get_debug(self):