From 49e2d1aa4f6ca3d18045b543fdb062468fa93924 Mon Sep 17 00:00:00 2001 From: Michael Lazar Date: Mon, 14 Dec 2015 23:37:23 -0800 Subject: [PATCH] Config now loads default values from a file alongside the source. --- MANIFEST.in | 3 +- rtv/__main__.py | 17 ++- rtv/config.py | 154 +++++++++++++--------- rtv/oauth.py | 4 +- rtv/rtv.cfg | 49 +++++++ scripts/build_manpage.py | 4 +- {rtv/templates => scripts}/rtv.1.template | 0 tests/test_config.py | 48 +++++-- 8 files changed, 196 insertions(+), 83 deletions(-) create mode 100644 rtv/rtv.cfg rename {rtv/templates => scripts}/rtv.1.template (100%) diff --git a/MANIFEST.in b/MANIFEST.in index 482bbe9..10eea76 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,4 +4,5 @@ include CONTRIBUTORS.rst include README.rst include LICENSE include rtv.1 -include rtv/templates/index.html +include rtv/rtv.cfg +include rtv/templates/* diff --git a/rtv/__main__.py b/rtv/__main__.py index 99668ca..85c8f24 100644 --- a/rtv/__main__.py +++ b/rtv/__main__.py @@ -9,7 +9,7 @@ import praw import tornado from . import docs -from .config import Config +from .config import Config, copy_default_config from .oauth import OAuthHelper from .terminal import Terminal from .objects import curses_session @@ -38,11 +38,18 @@ def main(): title = 'rtv {0}'.format(__version__) sys.stdout.write('\x1b]2;{0}\x07'.format(title)) - # Attempt to load from the config file first, and then overwrite with any - # provided command line arguments. + args = Config.get_args() + fargs = Config.get_file(args.get('config')) + + # Apply the file config first, then overwrite with any command line args config = Config() - config.from_file() - config.from_args() + config.update(**fargs) + config.update(**args) + + # Copy the default config file and quit + if config['copy_config']: + copy_default_config() + return # Load the browsing history from previous sessions config.load_history() diff --git a/rtv/config.py b/rtv/config.py index b0ba012..0360a9d 100644 --- a/rtv/config.py +++ b/rtv/config.py @@ -3,20 +3,23 @@ from __future__ import unicode_literals import os import codecs +import shutil import argparse +from functools import partial -from six.moves import configparser +from six.moves import configparser, input from . import docs, __version__ -HOME = os.path.expanduser('~') PACKAGE = os.path.dirname(__file__) +HOME = os.path.expanduser('~') +TEMPLATE = os.path.join(PACKAGE, 'templates') +DEFAULT_CONFIG = os.path.join(PACKAGE, 'rtv.cfg') XDG_HOME = os.getenv('XDG_CONFIG_HOME', os.path.join(HOME, '.config')) CONFIG = os.path.join(XDG_HOME, 'rtv', 'rtv.cfg') TOKEN = os.path.join(XDG_HOME, 'rtv', 'refresh-token') HISTORY = os.path.join(XDG_HOME, 'rtv', 'history.log') -TEMPLATE = os.path.join(PACKAGE, 'templates') def build_parser(): @@ -33,22 +36,49 @@ def build_parser(): parser.add_argument( '-l', dest='link', help='full URL of a submission that will be opened on start') + parser.add_argument( + '--log', metavar='FILE', action='store', + help='log HTTP requests') + parser.add_argument( + '--config', metavar='FILE', action='store', + help='Load configuration settings') parser.add_argument( '--ascii', action='store_const', const=True, help='enable ascii-only mode') - parser.add_argument( - '--log', metavar='FILE', action='store', - help='log HTTP requests to a file') parser.add_argument( '--non-persistent', dest='persistent', action='store_const', const=False, - help='Forget all authenticated users when the program exits') + help='Forget the authenticated user when the program exits') parser.add_argument( '--clear-auth', dest='clear_auth', action='store_const', const=True, - help='Remove any saved OAuth tokens before starting') + help='Remove any saved user data before launching') + parser.add_argument( + '--copy-config', dest='copy_config', action='store_const', const=True, + help='Copy the default configuration to {HOME}/.config/rtv/rtv.cfg') return parser +def copy_default_config(filename=CONFIG): + """ + Copy the default configuration file to the user's {HOME}/.config/rtv + """ + + if os.path.exists(filename): + try: + ch = input('File %s already exists, overwrite? y/[n]):' % filename) + if ch not in ('Y', 'y'): + return + except KeyboardInterrupt: + return + + filepath = os.path.dirname(filename) + if not os.path.exists(filepath): + os.makedirs(filepath) + + print('Copying default settings to %s' % filename) + shutil.copy(DEFAULT_CONFIG, filename) + + class OrderedSet(object): """ A simple implementation of an ordered set. A set is used to check @@ -76,36 +106,12 @@ class OrderedSet(object): class Config(object): - DEFAULT = { - 'ascii': False, - 'persistent': True, - 'clear_auth': False, - 'log': None, - 'link': None, - 'subreddit': 'front', - 'history_size': 200, - # https://github.com/reddit/reddit/wiki/OAuth2 - # Client ID is of type "installed app" and the secret should be empty - 'oauth_client_id': 'E2oEtRQfdfAfNQ', - 'oauth_client_secret': 'praw_gapfill', - 'oauth_redirect_uri': 'http://127.0.0.1:65000/', - 'oauth_redirect_port': 65000, - 'oauth_scope': [ - 'edit', 'history', 'identity', 'mysubreddits', 'privatemessages', - 'read', 'report', 'save', 'submit', 'subscribe', 'vote'], - 'template_path': TEMPLATE, - } + def __init__(self, history_file=HISTORY, token_file=TOKEN, **kwargs): - def __init__(self, - config_file=CONFIG, - history_file=HISTORY, - token_file=TOKEN, - **kwargs): - - self.config_file = config_file self.history_file = history_file self.token_file = token_file self.config = kwargs + self.default = self.get_file(DEFAULT_CONFIG) # `refresh_token` and `history` are saved/loaded at separate locations, # so they are treated differently from the rest of the config options. @@ -113,7 +119,10 @@ class Config(object): self.history = OrderedSet() def __getitem__(self, item): - return self.config.get(item, self.DEFAULT.get(item)) + if item in self.config: + return self.config[item] + else: + return self.default.get(item, None) def __setitem__(self, key, value): self.config[key] = value @@ -124,33 +133,6 @@ class Config(object): def update(self, **kwargs): self.config.update(kwargs) - def from_args(self): - parser = build_parser() - args = vars(parser.parse_args()) - # Filter out argument values that weren't supplied - args = {key: val for key, val in args.items() if val is not None} - self.update(**args) - - def from_file(self): - config = configparser.ConfigParser() - if os.path.exists(self.config_file): - with codecs.open(self.config_file, encoding='utf-8') as fp: - config.readfp(fp) - - config_dict = {} - if config.has_section('rtv'): - config_dict = dict(config.items('rtv')) - - # Convert 'true'/'false' to boolean True/False - if 'ascii' in config_dict: - config_dict['ascii'] = config.getboolean('rtv', 'ascii') - if 'clear_auth' in config_dict: - config_dict['clear_auth'] = config.getboolean('rtv', 'clear_auth') - if 'persistent' in config_dict: - config_dict['persistent'] = config.getboolean('rtv', 'persistent') - - self.update(**config_dict) - def load_refresh_token(self): if os.path.exists(self.token_file): with open(self.token_file) as fp: @@ -185,6 +167,54 @@ class Config(object): os.remove(self.history_file) self.history = OrderedSet() + @staticmethod + def get_args(): + """ + Load settings from the command line. + """ + + parser = build_parser() + args = vars(parser.parse_args()) + # Filter out argument values that weren't supplied + return {key: val for key, val in args.items() if val is not None} + + @classmethod + def get_file(cls, filename=None): + """ + Load settings from an rtv configuration file. + """ + + if filename is None: + filename = CONFIG + + config = configparser.ConfigParser() + if os.path.exists(filename): + with codecs.open(filename, encoding='utf-8') as fp: + config.readfp(fp) + + return cls._parse_rtv_file(config) + + @staticmethod + def _parse_rtv_file(config): + + out = {} + if config.has_section('rtv'): + out = dict(config.items('rtv')) + + params = { + 'ascii': partial(config.getboolean, 'rtv'), + 'clear_auth': partial(config.getboolean, 'rtv'), + 'persistent': partial(config.getboolean, 'rtv'), + 'history_size': partial(config.getint, 'rtv'), + 'oauth_redirect_port': partial(config.getint, 'rtv'), + 'oauth_scope': lambda x: out[x].split(',') + } + + for key, func in params.items(): + if key in out: + out[key] = func(key) + return out + @staticmethod def _ensure_filepath(filename): """ diff --git a/rtv/oauth.py b/rtv/oauth.py index 95a88c1..e18fdd7 100644 --- a/rtv/oauth.py +++ b/rtv/oauth.py @@ -7,6 +7,8 @@ import uuid from concurrent.futures import ThreadPoolExecutor from tornado import gen, ioloop, web, httpserver +from .config import TEMPLATE + class OAuthHandler(web.RequestHandler): """ @@ -54,7 +56,7 @@ class OAuthHelper(object): kwargs = {'display': self.term.display, 'params': self.params} routes = [('/', OAuthHandler, kwargs)] self.callback_app = web.Application( - routes, template_path=self.config['template_path']) + routes, template_path=TEMPLATE) self.reddit.set_oauth_app_info( self.config['oauth_client_id'], diff --git a/rtv/rtv.cfg b/rtv/rtv.cfg new file mode 100644 index 0000000..35059b5 --- /dev/null +++ b/rtv/rtv.cfg @@ -0,0 +1,49 @@ +; Reddit Terminal Viewer Configuration File +; https://github.com/michael-lazar/rtv +; +; This file should be placed in $XDG_CONFIG/rtv/rtv.cfg +; If $XDG_CONFIG is not set, use ~/.config/rtv/rtv.cfg + +[rtv] +################## +# General Settings +################## + +; Turn on ascii-only mode to disable all unicode characters. +; This may be necessary for compatibility with some terminal browsers. +ascii = False + +; Enable debugging by logging all HTTP requests and errors to the given file. +;log = /tmp/rtv.log + +; Default subreddit that will be opened when the program launches. +subreddit = front + +; Allow rtv to store reddit authentication credentials between sessions. +persistent = True + +; Clear any stored credentials when the program starts. +clear_auth = False + +; Maximum number of opened links that will be saved in the history file. +history_size = 200 + +################ +# OAuth Settings +################ +; This sections defines the paramaters that will be used during the OAuth +; authentication process. RTV is registered as an "installed app", +; see https://github.com/reddit/reddit/wiki/OAuth2 for more information. + +; These settings are defined at https://www.reddit.com/prefs/apps and should +; not be altered unless you are defining your own developer application. +oauth_client_id = E2oEtRQfdfAfNQ +oauth_client_secret = praw_gapfill +oauth_redirect_uri = http://127.0.0.1:65000/ + +; Port that the rtv webserver will listen on. This should match the redirect +; uri defined above. +oauth_redirect_port = 65000 + +; Access permissions that will be requested. +oauth_scope = edit,history,identity,mysubreddits,privatemessages,read,report,save,submit,subscribe,vote \ No newline at end of file diff --git a/scripts/build_manpage.py b/scripts/build_manpage.py index b9fd431..c6e14d6 100644 --- a/scripts/build_manpage.py +++ b/scripts/build_manpage.py @@ -56,8 +56,8 @@ def main(): data['copyright'] = rtv.__copyright__ # Escape dashes is all of the sections data = {k: v.replace('-', r'\-') for k, v in data.items()} - print('Reading from %s/rtv/templates/rtv.1.template' % ROOT) - with open(os.path.join(ROOT, 'rtv/templates/rtv.1.template')) as fp: + print('Reading from %s/rtv/scripts/rtv.1.template' % ROOT) + with open(os.path.join(ROOT, 'rtv/scripts/rtv.1.template')) as fp: template = fp.read() print('Populating template') out = template.format(**data) diff --git a/rtv/templates/rtv.1.template b/scripts/rtv.1.template similarity index 100% rename from rtv/templates/rtv.1.template rename to scripts/rtv.1.template diff --git a/tests/test_config.py b/tests/test_config.py index 8a55315..d458778 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,7 +5,7 @@ import os import codecs from tempfile import NamedTemporaryFile -from rtv.config import Config +from rtv.config import Config, copy_default_config, DEFAULT_CONFIG try: from unittest import mock @@ -13,6 +13,18 @@ except ImportError: import mock +def test_copy_default_config(): + "Make sure the default config file was included in the package" + + with NamedTemporaryFile(suffix='.cfg') as fp: + with mock.patch('rtv.config.input', return_value='y'): + copy_default_config(fp.name) + with open(DEFAULT_CONFIG) as fp_default: + assert fp.read() == fp_default.read() + permissions = os.stat(fp.name).st_mode & 0o777 + assert permissions == 0o664 + + def test_config_interface(): "Test setting and removing values" @@ -20,40 +32,50 @@ def test_config_interface(): assert config['ascii'] is True config['ascii'] = False assert config['ascii'] is False - config['ascii'] = True + config['ascii'] = None + assert config['ascii'] is None del config['ascii'] assert config['ascii'] is False + config.update(subreddit='cfb', new_value=2.0) assert config['subreddit'] == 'cfb' assert config['new_value'] == 2.0 + assert config['link'] is None + assert config['log'] is None -def test_config_from_args(): + +def test_config_get_args(): "Ensure that command line arguments are parsed properly" args = ['rtv', '-s', 'cfb', '-l', 'https://reddit.com/permalink •', '--log', 'logfile.log', + '--config', 'configfile.cfg', '--ascii', '--non-persistent', - '--clear-auth'] + '--clear-auth', + '--copy-config'] with mock.patch('sys.argv', ['rtv']): - config = Config() - config.from_args() + config_dict = Config.get_args() + config = Config(**config_dict) assert config.config == {} with mock.patch('sys.argv', args): - config = Config() - config.from_args() + config_dict = Config.get_args() + + config = Config(**config_dict) assert config['ascii'] is True assert config['subreddit'] == 'cfb' - assert config['link'] == 'https://reddit.com/permalink •' assert config['log'] == 'logfile.log' assert config['ascii'] is True assert config['persistent'] is False assert config['clear_auth'] is True + assert config['link'] == 'https://reddit.com/permalink •' + assert config['config'] == 'configfile.cfg' + assert config['copy_config'] is True def test_config_from_file(): @@ -68,15 +90,17 @@ def test_config_from_file(): 'subreddit': 'cfb'} with NamedTemporaryFile(suffix='.cfg') as fp: - config = Config(config_file=fp.name) - config.from_file() + + fargs = Config.get_file(filename=fp.name) + config = Config(**fargs) assert config.config == {} rows = ['{0}={1}'.format(key, val) for key, val in args.items()] data = '\n'.join(['[rtv]'] + rows) fp.write(codecs.encode(data, 'utf-8')) fp.flush() - config.from_file() + fargs = Config.get_file(filename=fp.name) + config.update(**fargs) assert config.config == args