From 592f5f80d00a341d69440343ad83308006c1e015 Mon Sep 17 00:00:00 2001 From: James Kasten Date: Thu, 4 Oct 2012 15:34:38 -0400 Subject: [PATCH] Added checkpointing and rollback feature for configuration --- trustify/client/CONFIG.py | 2 +- trustify/client/client.py | 23 ++- trustify/client/configurator.py | 244 +++++++++++++++++++++++++------- 3 files changed, 211 insertions(+), 58 deletions(-) diff --git a/trustify/client/CONFIG.py b/trustify/client/CONFIG.py index ba0f23f8a..fc35cf558 100644 --- a/trustify/client/CONFIG.py +++ b/trustify/client/CONFIG.py @@ -18,7 +18,7 @@ OPTIONS_SSL_CONF = CONFIG_DIR + "options-ssl.conf" # Temporary file for challenge virtual hosts APACHE_CHALLENGE_CONF = CONFIG_DIR + "choc_sni_cert_challenge.conf" # Modified files intended to be reset (for challenges/tmp config changes) -MODIFIED_FILES = BACKUP_DIR + "modified_files" +MODIFIED_FILES = WORK_DIR + "modified_files" # Byte size of S and Nonce S_SIZE = 32 NONCE_SIZE = 32 diff --git a/trustify/client/client.py b/trustify/client/client.py index 7cecea095..546011c6e 100644 --- a/trustify/client/client.py +++ b/trustify/client/client.py @@ -231,6 +231,12 @@ def save_key_csr(key, csr): return key_fn, csr_fn +def recognized_ca(issuer): + pass + +def gen_req_from_cert(): + return + def unique_file(default_name, mode = 0777): """ Safely finds a unique file for writing only (by default) @@ -311,9 +317,10 @@ def send_request(key_pem, csr_pem, quiet=curses): return r, k -def handle_verification_response(r, dn, challenge, vhost, key_file, config): +def handle_verification_response(r, dn, challenges, vhost, key_file, config): if r.success.IsInitialized(): - challenge.cleanup() + for chall in challenges: + chall.cleanup() cert_chain_abspath = None with open(cert_file, "w") as f: f.write(r.success.certificate) @@ -366,6 +373,17 @@ def redirect_to_ssl(vhost, config): config.enable_site(redirect_vhost) logger.info("Enabling available site: " + redirect_vhost.file) +def renew(config): + cert_key_pairs = config.get_all_certs_keys() + for tup in cert_key_pairs: + cert = M2Crypto.X509.load_cert(tup[0]) + issuer = cert.get_issuer() + if recognized_ca(issuer): + generate_renewal_req() + + # Wait for response, act accordingly + gen_req_from_cert() + def authenticate(): """ Main call to do DV_SNI validation and deploy the trustify certificate @@ -448,6 +466,7 @@ def authenticate(): k.session = r.session r = decode(do(upstream, k)) logger.debug(r) + # TODO: This needs to be rewritten to handle multiple challenges handle_verification_response(r, dn, challenges[0], vhost, key_file, config) diff --git a/trustify/client/configurator.py b/trustify/client/configurator.py index e86bb1c8a..df621d5ed 100644 --- a/trustify/client/configurator.py +++ b/trustify/client/configurator.py @@ -11,6 +11,7 @@ from trustify.client.CONFIG import SERVER_ROOT, BACKUP_DIR, MODIFIED_FILES #from CONFIG import SERVER_ROOT, BACKUP_DIR, MODIFIED_FILES, REWRITE_HTTPS_ARGS from trustify.client.CONFIG import REWRITE_HTTPS_ARGS from trustify.client import logger +#import logger #TODO - Need an initialization routine... make sure directories exist..ect @@ -46,6 +47,9 @@ class Configurator(object): #for m in self.aug.match("/augeas/load/Httpd/incl"): #self.httpd_incl.append(self.aug.get(m)) + self.save_notes = "" + # new_files is for save checkpoints and to allow reverts + self.new_files = [] self.vhosts = self.get_virtual_hosts() # Add name_server association dict self.assoc = dict() @@ -92,7 +96,13 @@ class Configurator(object): else: self.aug.set(path["cert_chain"][0], cert_chain) - return self.save("Virtual Server - deploying certificate") + self.save_notes += "Changed vhost at %s with addresses of %s" % (vhost.file, vhost.addrs) + self.save_notes += "\tSSLCertificateFile %s\n" % cert + self.save_notes += "\tSSLCertificateKeyFile %s\n" % key + if cert_chain: + self.save_notes += "\tSSLCertificateChainFile %s\n" % cert_chain + # This is a significant operation, make a checkpoint + return self.save("Virtual Server - deploying certificate", False) def choose_virtual_host(self, name, ssl=True): """ @@ -234,7 +244,8 @@ class Configurator(object): logger.warn("ports.conf is not included in your Apache config...") logger.warn("Adding NameVirtualHost directive to httpd.conf") self.add_dir_to_ifmodssl("/files" + SERVER_ROOT + "httpd.conf", "NameVirtualHost", addr) - + + self.save_notes += 'Setting %s to be NameBasedVirtualHost\n' % addr def add_dir_to_ifmodssl(self, aug_conf_path, directive, val): """ @@ -268,6 +279,7 @@ class Configurator(object): logger.debug("No Listen 443 directive found") logger.debug("Setting the Apache Server to Listen on port 443") self.add_dir_to_ifmodssl("/files" + SERVER_ROOT + "ports.conf", "Listen", "443") + self.save_notes += "Added Listen 443 directive to ports.conf\n" # Check for NameVirtualHost # First see if any of the vhost addresses is a _default_ addr @@ -277,6 +289,7 @@ class Configurator(object): if not self.is_name_vhost(default_addr): logger.debug("Setting all VirtualHosts on " + default_addr + " to be name based virtual hosts") self.add_name_vhost(default_addr) + return True # No default addresses... so set each one individually for addr in vhost.addrs: @@ -411,6 +424,8 @@ class Configurator(object): new_file.write("\n") orig_file.close() new_file.close() + # This is used for checkpoints + self.new_files.append(ssl_fp) self.aug.load() # change address to address:443, address:80 @@ -436,6 +451,8 @@ class Configurator(object): # reload configurator vhosts self.vhosts = self.get_virtual_hosts() + self.save_notes += 'Created ssl vhost at %s\n' % ssl_fp + return ssl_fp def redirect_all_ssl(self, ssl_vhost): @@ -463,6 +480,7 @@ class Configurator(object): #Add directives to server self.add_dir(general_v.path, "RewriteEngine", "On") self.add_dir(general_v.path, "RewriteRule", REWRITE_HTTPS_ARGS) + self.save_notes += 'Redirecting host in %s to ssl vhost in %s\n' % (general_v.file, ssl_vhost.file) self.save("Redirect all to ssl") return True, general_v @@ -537,11 +555,15 @@ LogLevel warn \n\ with open(SERVER_ROOT+"sites-available/"+redirect_filename, 'w') as f: f.write(redirect_file) logger.info("Created redirect file: " + redirect_filename) + self.new_files.append(redirect_filename) self.aug.load() new_fp = SERVER_ROOT + "sites-available/" + redirect_filename new_vhost = self.__create_vhost("/files" + new_fp) - self.vhosts.add(self.__create_vhost("/files" + new_fp)) + self.vhosts.add(new_vhost) + + self.save_notes += 'Created a port 80 vhost, %s, for redirection to ssl vhost %s\n' % (new_vhost.file, ssl_vhost.file) + return True, new_vhost def __conflicting_host(self, ssl_vhost): @@ -611,17 +633,28 @@ LogLevel warn \n\ return vh return None - def get_all_certs(self): + def get_all_certs_keys(self): """ - Retrieve all certs on the Apache server - returns: set of file paths + Retrieve all certs and keys set in VirtualHosts on the Apache server + returns: list of tuples with form [(cert, key)] """ - cert_path = self.find_directive("SSLCertificateFile") - file_paths = set() - for p in cert_path: - file_paths.add(self.aug.get(p)) - return file_paths + cert_key_pairs = set() + + for vhost in self.vhosts: + if vhost.ssl: + cert_path = self.find_directive("SSLCertificateFile", None, vhost.path) + key_path = self.find_directive("SSLCertificateKeyFile", None, vhost.path) + # Can be removed once find directive can return ordered results + if cert_path != 1 or key_path != 1: + logger.error("Too many cert or key directives in vhost %s" % vhost.file) + sys.exit(40) + + cert = os.path.abspath(self.aug.get(cert_path[0])) + key = os.path.abspath(self.aug.get(key_path[0])) + cert_key_pairs.add( (cert,key) ) + + return cert_key_pairs def get_file_path(self, vhost_path): # Strip off /files @@ -661,6 +694,7 @@ LogLevel warn \n\ index = vhost.file.rfind("/") os.symlink(vhost.file, SERVER_ROOT + "sites-enabled/" + vhost.file[index:]) vhost.enabled = True + self.save_notes += 'Enabled site %s\n' % vhost.file return True return False @@ -709,41 +743,6 @@ LogLevel warn \n\ #self.aug.add_transform("Httpd.lns", self.httpd_incl, None, self.httpd_excl) self.__add_httpd_transform(file_path) self.aug.load() - - def save(self, mod_conf="Augeas Configuration", reversible=False): - """ - Saves all changes to the configuration files - Backups are stored as *.augsave files - This function is not transactional - TODO: Instead rely on challenge to backup all files before modifications - - mod_conf: string - Error message presented in case of problem - useful for debugging - reversible: boolean - Indicates whether the changes made will be - reversed in the future - """ - try: - self.aug.save() - # Retrieve list of modified files - save_paths = self.aug.match("/augeas/events/saved") - mod_fd = open(MODIFIED_FILES, 'r+') - mod_files = mod_fd.readlines() - for path in save_paths: - # Strip off /files - filename = self.aug.get(path)[6:] - if filename in mod_files: - # Output a warning... hopefully this can be avoided so more - # complex code doesn't have to be written - logger.fatal("Reversible file has been overwritten - %s" % filename) - sys.exit(37) - if reversible: - mod_fd.write(filename + "\n") - mod_fd.close() - return True - except IOError: - logger.error("Unable to save file - %s" % mod_conf) - logger.error("Is the script running as root?") - return False def save_apache_config(self): # Should be safe because it is a protected directory @@ -784,9 +783,6 @@ LogLevel warn \n\ """ This function should reload the users original configuration files for all saves with reversible=True - TODO: This should probably instead pull in files from the - backup directory... move away from augsave so the user doesn't - do anything unexpectedly """ if mod_files is None: try: @@ -834,17 +830,145 @@ LogLevel warn \n\ """ This function will correctly add a transform to augeas The existing augeas.add_transform in python is broken - The augeas.set function it uses appends new includes to the end of the - tree which overrules the exclude parameters. """ lastInclude = self.aug.match("/augeas/load/Httpd/incl [last()]") self.aug.insert(lastInclude[0], "incl", False) self.aug.set("/augeas/load/Httpd/incl[last()]", incl) + + def save(self, mod_conf="Augeas Configuration", reversible=False): + """ + Saves all changes to the configuration files + Backups are stored as *.augsave files + This function is not transactional + TODO: Instead rely on challenge to backup all files before modifications + mod_conf: string - Error message presented in case of problem + useful for debugging + reversible: boolean - Indicates whether the changes made will be + reversed in the future + """ + save_state = self.aug.get("/augeas/save") + self.aug.set("/augeas/save", "noop") + ex_errs = self.aug.match("/augeas//error") + try: + # This is a noop save + self.aug.save() + except: + # Check for the root of save problems + new_errs = self.aug.match("/augeas//error") + # Only print new errors caused by recent save + for err in new_errs: + if err not in ex_errs: + print "Error Saving - " + mod_conf + print "Unable to save - " + err[13:len(err)-6] + print "Attempted Save Notes\n" + print self.save_notes + # Erase Save Notes + self.save_notes = "" + return False + + # Retrieve list of modified files + # Note: Noop saves can cause the file to be listed twice, used set to + # remove this possibility + save_paths = self.aug.match("/augeas/events/saved") + save_files = set() + for p in save_paths: + save_files.add(self.aug.get(p)[6:]) + + for f in save_files: + print f + valid, message = self.check_tempfile_saves(save_files, reversible) + + if not valid: + logger.error(message) + return False + + # Create Checkpoint + if not reversible: + self.create_checkpoint(save_files, mod_conf) + self.aug.set("/augeas/save", save_state) + self.save_notes = "" + del self.new_files[:] + self.aug.save() + + return True + + def create_checkpoint(self, save_files, mod_conf): + cp_dir = BACKUP_DIR + str(time.time()) + try: + #os.makedirs(BACKUP_DIR + datetime.date.today().strftime('%m-%y')) + os.makedirs(cp_dir) + except OSError as exception: + if exception.errno != errno.EEXIST: + raise + #Update cp_dir for cleaner path creation + cp_dir = cp_dir + "/" + + with open(cp_dir + "FILEPATHS", 'w') as op_fd: + for idx, filename in enumerate(save_files): + # Tag files with index so multiple files can have same basename + logger.debug("Creating backup of %s" % filename) + shutil.copy2(filename, cp_dir + os.path.basename(filename) + "_" + str(idx)) + op_fd.write(filename + '\n') + + with open(cp_dir + "CHANGES_SINCE", 'w') as notes_fd: + notes_fd.write(self.save_notes) + + # Mark any new files that have been created + # The files will be deleted if the checkpoint is rolledback + if self.new_files: + with open(cp_dir + "NEW_FILES", 'w') as nf_fd: + for filename in self.new_files: + nf_fd.write(filename + '\n') + + def recover_checkpoint(self, rollback = 1): + backups = os.listdir(BACKUP_DIR) + backups.sort() + + if len(backups) < rollback: + logger.error("Unable to rollback %d checkpoints, only %d exist" % (rollback, len(backups))) + + while rollback > 0 and backups: + cp_dir = BACKUP_DIR + backups.pop() + with open(cp_dir + "/FILEPATHS") as f: + filepaths = f.read().splitlines() + for idx, fp in enumerate(filepaths): + shutil.copy2(cp_dir + '/' + os.path.basename(fp) + '_' + str(idx), fp) + try: + # Remove any newly added files if they exist + with open(cp_dir + "/NEW_FILES") as f: + filepaths = f.read().splitlines() + for fp in filepaths: + os.remove(fp) + except: + pass + + try: + shutil.rmtree(cp_dir) + except: + logger.error("Unable to remove directory: %s" % cp_dir) + rollback -= 1 + + self.aug.load() + + def check_tempfile_saves(self, save_files, reversible): + protected_fd = open(MODIFIED_FILES, 'r+') + protected_files = protected_fd.read().splitlines() + for filename in save_files: + if filename in protected_files: + protected_fd.close() + return False, "Attempting to overwrite a reversible file - %s" %filename + # No protected files are trying to be overwritten + if reversible: + for filename in save_files: + protected_fd.write(filename + "\n") + + protected_fd.close() + return True, "Successful" def main(): config = Configurator() - logger.setLogger(sys.stdout) + logger.setLogger(logger.FileLogger(sys.stdout)) logger.setLogLevel(logger.DEBUG) for v in config.vhosts: print v.file @@ -860,11 +984,21 @@ def main(): print "Address:",a, "- Is name vhost?", config.is_name_vhost(a) print config.get_all_names() - + """ + test_file = "/home/james/Desktop/ports_test.conf" + config.parse_file(test_file) - config.parse_file("/etc/apache2/ports_test.conf") + config.aug.insert("/files" + test_file + "/IfModule[1]/arg", "directive", False) + config.aug.set("/files" + test_file + "/IfModule[1]/directive[1]", "Listen") + config.aug.set("/files" + test_file + "/IfModule[1]/directive[1]/arg", "556") + config.aug.set("/files" + test_file + "/IfModule[1]/directive[2]", "Listen") + config.aug.set("/files" + test_file + "/IfModule[1]/directive[2]/arg", "555") - config.restart() + #config.save_notes = "Added listen 431 for test" + #config.new_files.append("/home/james/Desktop/new_file.txt") + #config._save("testing_saves", False) + config.recover_checkpoint(1) + """ """ #config.make_vhost_ssl("/etc/apache2/sites-available/default")