Config now loads default values from a file alongside the source.

This commit is contained in:
Michael Lazar
2015-12-14 23:37:23 -08:00
parent b487f70e48
commit 49e2d1aa4f
8 changed files with 196 additions and 83 deletions

View File

@@ -4,4 +4,5 @@ include CONTRIBUTORS.rst
include README.rst include README.rst
include LICENSE include LICENSE
include rtv.1 include rtv.1
include rtv/templates/index.html include rtv/rtv.cfg
include rtv/templates/*

View File

@@ -9,7 +9,7 @@ import praw
import tornado import tornado
from . import docs from . import docs
from .config import Config from .config import Config, copy_default_config
from .oauth import OAuthHelper from .oauth import OAuthHelper
from .terminal import Terminal from .terminal import Terminal
from .objects import curses_session from .objects import curses_session
@@ -38,11 +38,18 @@ def main():
title = 'rtv {0}'.format(__version__) title = 'rtv {0}'.format(__version__)
sys.stdout.write('\x1b]2;{0}\x07'.format(title)) sys.stdout.write('\x1b]2;{0}\x07'.format(title))
# Attempt to load from the config file first, and then overwrite with any args = Config.get_args()
# provided command line arguments. fargs = Config.get_file(args.get('config'))
# Apply the file config first, then overwrite with any command line args
config = Config() config = Config()
config.from_file() config.update(**fargs)
config.from_args() config.update(**args)
# Copy the default config file and quit
if config['copy_config']:
copy_default_config()
return
# Load the browsing history from previous sessions # Load the browsing history from previous sessions
config.load_history() config.load_history()

View File

@@ -3,20 +3,23 @@ from __future__ import unicode_literals
import os import os
import codecs import codecs
import shutil
import argparse import argparse
from functools import partial
from six.moves import configparser from six.moves import configparser, input
from . import docs, __version__ from . import docs, __version__
HOME = os.path.expanduser('~')
PACKAGE = os.path.dirname(__file__) PACKAGE = os.path.dirname(__file__)
HOME = os.path.expanduser('~')
TEMPLATE = os.path.join(PACKAGE, 'templates')
DEFAULT_CONFIG = os.path.join(PACKAGE, 'rtv.cfg')
XDG_HOME = os.getenv('XDG_CONFIG_HOME', os.path.join(HOME, '.config')) XDG_HOME = os.getenv('XDG_CONFIG_HOME', os.path.join(HOME, '.config'))
CONFIG = os.path.join(XDG_HOME, 'rtv', 'rtv.cfg') CONFIG = os.path.join(XDG_HOME, 'rtv', 'rtv.cfg')
TOKEN = os.path.join(XDG_HOME, 'rtv', 'refresh-token') TOKEN = os.path.join(XDG_HOME, 'rtv', 'refresh-token')
HISTORY = os.path.join(XDG_HOME, 'rtv', 'history.log') HISTORY = os.path.join(XDG_HOME, 'rtv', 'history.log')
TEMPLATE = os.path.join(PACKAGE, 'templates')
def build_parser(): def build_parser():
@@ -33,22 +36,49 @@ def build_parser():
parser.add_argument( parser.add_argument(
'-l', dest='link', '-l', dest='link',
help='full URL of a submission that will be opened on start') help='full URL of a submission that will be opened on start')
parser.add_argument(
'--log', metavar='FILE', action='store',
help='log HTTP requests')
parser.add_argument(
'--config', metavar='FILE', action='store',
help='Load configuration settings')
parser.add_argument( parser.add_argument(
'--ascii', action='store_const', const=True, '--ascii', action='store_const', const=True,
help='enable ascii-only mode') help='enable ascii-only mode')
parser.add_argument(
'--log', metavar='FILE', action='store',
help='log HTTP requests to a file')
parser.add_argument( parser.add_argument(
'--non-persistent', dest='persistent', action='store_const', '--non-persistent', dest='persistent', action='store_const',
const=False, const=False,
help='Forget all authenticated users when the program exits') help='Forget the authenticated user when the program exits')
parser.add_argument( parser.add_argument(
'--clear-auth', dest='clear_auth', action='store_const', const=True, '--clear-auth', dest='clear_auth', action='store_const', const=True,
help='Remove any saved OAuth tokens before starting') help='Remove any saved user data before launching')
parser.add_argument(
'--copy-config', dest='copy_config', action='store_const', const=True,
help='Copy the default configuration to {HOME}/.config/rtv/rtv.cfg')
return parser return parser
def copy_default_config(filename=CONFIG):
"""
Copy the default configuration file to the user's {HOME}/.config/rtv
"""
if os.path.exists(filename):
try:
ch = input('File %s already exists, overwrite? y/[n]):' % filename)
if ch not in ('Y', 'y'):
return
except KeyboardInterrupt:
return
filepath = os.path.dirname(filename)
if not os.path.exists(filepath):
os.makedirs(filepath)
print('Copying default settings to %s' % filename)
shutil.copy(DEFAULT_CONFIG, filename)
class OrderedSet(object): class OrderedSet(object):
""" """
A simple implementation of an ordered set. A set is used to check A simple implementation of an ordered set. A set is used to check
@@ -76,36 +106,12 @@ class OrderedSet(object):
class Config(object): class Config(object):
DEFAULT = { def __init__(self, history_file=HISTORY, token_file=TOKEN, **kwargs):
'ascii': False,
'persistent': True,
'clear_auth': False,
'log': None,
'link': None,
'subreddit': 'front',
'history_size': 200,
# https://github.com/reddit/reddit/wiki/OAuth2
# Client ID is of type "installed app" and the secret should be empty
'oauth_client_id': 'E2oEtRQfdfAfNQ',
'oauth_client_secret': 'praw_gapfill',
'oauth_redirect_uri': 'http://127.0.0.1:65000/',
'oauth_redirect_port': 65000,
'oauth_scope': [
'edit', 'history', 'identity', 'mysubreddits', 'privatemessages',
'read', 'report', 'save', 'submit', 'subscribe', 'vote'],
'template_path': TEMPLATE,
}
def __init__(self,
config_file=CONFIG,
history_file=HISTORY,
token_file=TOKEN,
**kwargs):
self.config_file = config_file
self.history_file = history_file self.history_file = history_file
self.token_file = token_file self.token_file = token_file
self.config = kwargs self.config = kwargs
self.default = self.get_file(DEFAULT_CONFIG)
# `refresh_token` and `history` are saved/loaded at separate locations, # `refresh_token` and `history` are saved/loaded at separate locations,
# so they are treated differently from the rest of the config options. # so they are treated differently from the rest of the config options.
@@ -113,7 +119,10 @@ class Config(object):
self.history = OrderedSet() self.history = OrderedSet()
def __getitem__(self, item): def __getitem__(self, item):
return self.config.get(item, self.DEFAULT.get(item)) if item in self.config:
return self.config[item]
else:
return self.default.get(item, None)
def __setitem__(self, key, value): def __setitem__(self, key, value):
self.config[key] = value self.config[key] = value
@@ -124,33 +133,6 @@ class Config(object):
def update(self, **kwargs): def update(self, **kwargs):
self.config.update(kwargs) self.config.update(kwargs)
def from_args(self):
parser = build_parser()
args = vars(parser.parse_args())
# Filter out argument values that weren't supplied
args = {key: val for key, val in args.items() if val is not None}
self.update(**args)
def from_file(self):
config = configparser.ConfigParser()
if os.path.exists(self.config_file):
with codecs.open(self.config_file, encoding='utf-8') as fp:
config.readfp(fp)
config_dict = {}
if config.has_section('rtv'):
config_dict = dict(config.items('rtv'))
# Convert 'true'/'false' to boolean True/False
if 'ascii' in config_dict:
config_dict['ascii'] = config.getboolean('rtv', 'ascii')
if 'clear_auth' in config_dict:
config_dict['clear_auth'] = config.getboolean('rtv', 'clear_auth')
if 'persistent' in config_dict:
config_dict['persistent'] = config.getboolean('rtv', 'persistent')
self.update(**config_dict)
def load_refresh_token(self): def load_refresh_token(self):
if os.path.exists(self.token_file): if os.path.exists(self.token_file):
with open(self.token_file) as fp: with open(self.token_file) as fp:
@@ -185,6 +167,54 @@ class Config(object):
os.remove(self.history_file) os.remove(self.history_file)
self.history = OrderedSet() self.history = OrderedSet()
@staticmethod
def get_args():
"""
Load settings from the command line.
"""
parser = build_parser()
args = vars(parser.parse_args())
# Filter out argument values that weren't supplied
return {key: val for key, val in args.items() if val is not None}
@classmethod
def get_file(cls, filename=None):
"""
Load settings from an rtv configuration file.
"""
if filename is None:
filename = CONFIG
config = configparser.ConfigParser()
if os.path.exists(filename):
with codecs.open(filename, encoding='utf-8') as fp:
config.readfp(fp)
return cls._parse_rtv_file(config)
@staticmethod
def _parse_rtv_file(config):
out = {}
if config.has_section('rtv'):
out = dict(config.items('rtv'))
params = {
'ascii': partial(config.getboolean, 'rtv'),
'clear_auth': partial(config.getboolean, 'rtv'),
'persistent': partial(config.getboolean, 'rtv'),
'history_size': partial(config.getint, 'rtv'),
'oauth_redirect_port': partial(config.getint, 'rtv'),
'oauth_scope': lambda x: out[x].split(',')
}
for key, func in params.items():
if key in out:
out[key] = func(key)
return out
@staticmethod @staticmethod
def _ensure_filepath(filename): def _ensure_filepath(filename):
""" """

View File

@@ -7,6 +7,8 @@ import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from tornado import gen, ioloop, web, httpserver from tornado import gen, ioloop, web, httpserver
from .config import TEMPLATE
class OAuthHandler(web.RequestHandler): class OAuthHandler(web.RequestHandler):
""" """
@@ -54,7 +56,7 @@ class OAuthHelper(object):
kwargs = {'display': self.term.display, 'params': self.params} kwargs = {'display': self.term.display, 'params': self.params}
routes = [('/', OAuthHandler, kwargs)] routes = [('/', OAuthHandler, kwargs)]
self.callback_app = web.Application( self.callback_app = web.Application(
routes, template_path=self.config['template_path']) routes, template_path=TEMPLATE)
self.reddit.set_oauth_app_info( self.reddit.set_oauth_app_info(
self.config['oauth_client_id'], self.config['oauth_client_id'],

49
rtv/rtv.cfg Normal file
View File

@@ -0,0 +1,49 @@
; Reddit Terminal Viewer Configuration File
; https://github.com/michael-lazar/rtv
;
; This file should be placed in $XDG_CONFIG/rtv/rtv.cfg
; If $XDG_CONFIG is not set, use ~/.config/rtv/rtv.cfg
[rtv]
##################
# General Settings
##################
; Turn on ascii-only mode to disable all unicode characters.
; This may be necessary for compatibility with some terminal browsers.
ascii = False
; Enable debugging by logging all HTTP requests and errors to the given file.
;log = /tmp/rtv.log
; Default subreddit that will be opened when the program launches.
subreddit = front
; Allow rtv to store reddit authentication credentials between sessions.
persistent = True
; Clear any stored credentials when the program starts.
clear_auth = False
; Maximum number of opened links that will be saved in the history file.
history_size = 200
################
# OAuth Settings
################
; This sections defines the paramaters that will be used during the OAuth
; authentication process. RTV is registered as an "installed app",
; see https://github.com/reddit/reddit/wiki/OAuth2 for more information.
; These settings are defined at https://www.reddit.com/prefs/apps and should
; not be altered unless you are defining your own developer application.
oauth_client_id = E2oEtRQfdfAfNQ
oauth_client_secret = praw_gapfill
oauth_redirect_uri = http://127.0.0.1:65000/
; Port that the rtv webserver will listen on. This should match the redirect
; uri defined above.
oauth_redirect_port = 65000
; Access permissions that will be requested.
oauth_scope = edit,history,identity,mysubreddits,privatemessages,read,report,save,submit,subscribe,vote

View File

@@ -56,8 +56,8 @@ def main():
data['copyright'] = rtv.__copyright__ data['copyright'] = rtv.__copyright__
# Escape dashes is all of the sections # Escape dashes is all of the sections
data = {k: v.replace('-', r'\-') for k, v in data.items()} data = {k: v.replace('-', r'\-') for k, v in data.items()}
print('Reading from %s/rtv/templates/rtv.1.template' % ROOT) print('Reading from %s/rtv/scripts/rtv.1.template' % ROOT)
with open(os.path.join(ROOT, 'rtv/templates/rtv.1.template')) as fp: with open(os.path.join(ROOT, 'rtv/scripts/rtv.1.template')) as fp:
template = fp.read() template = fp.read()
print('Populating template') print('Populating template')
out = template.format(**data) out = template.format(**data)

View File

@@ -5,7 +5,7 @@ import os
import codecs import codecs
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from rtv.config import Config from rtv.config import Config, copy_default_config, DEFAULT_CONFIG
try: try:
from unittest import mock from unittest import mock
@@ -13,6 +13,18 @@ except ImportError:
import mock import mock
def test_copy_default_config():
"Make sure the default config file was included in the package"
with NamedTemporaryFile(suffix='.cfg') as fp:
with mock.patch('rtv.config.input', return_value='y'):
copy_default_config(fp.name)
with open(DEFAULT_CONFIG) as fp_default:
assert fp.read() == fp_default.read()
permissions = os.stat(fp.name).st_mode & 0o777
assert permissions == 0o664
def test_config_interface(): def test_config_interface():
"Test setting and removing values" "Test setting and removing values"
@@ -20,40 +32,50 @@ def test_config_interface():
assert config['ascii'] is True assert config['ascii'] is True
config['ascii'] = False config['ascii'] = False
assert config['ascii'] is False assert config['ascii'] is False
config['ascii'] = True config['ascii'] = None
assert config['ascii'] is None
del config['ascii'] del config['ascii']
assert config['ascii'] is False assert config['ascii'] is False
config.update(subreddit='cfb', new_value=2.0) config.update(subreddit='cfb', new_value=2.0)
assert config['subreddit'] == 'cfb' assert config['subreddit'] == 'cfb'
assert config['new_value'] == 2.0 assert config['new_value'] == 2.0
assert config['link'] is None
assert config['log'] is None
def test_config_from_args():
def test_config_get_args():
"Ensure that command line arguments are parsed properly" "Ensure that command line arguments are parsed properly"
args = ['rtv', args = ['rtv',
'-s', 'cfb', '-s', 'cfb',
'-l', 'https://reddit.com/permalink •', '-l', 'https://reddit.com/permalink •',
'--log', 'logfile.log', '--log', 'logfile.log',
'--config', 'configfile.cfg',
'--ascii', '--ascii',
'--non-persistent', '--non-persistent',
'--clear-auth'] '--clear-auth',
'--copy-config']
with mock.patch('sys.argv', ['rtv']): with mock.patch('sys.argv', ['rtv']):
config = Config() config_dict = Config.get_args()
config.from_args() config = Config(**config_dict)
assert config.config == {} assert config.config == {}
with mock.patch('sys.argv', args): with mock.patch('sys.argv', args):
config = Config() config_dict = Config.get_args()
config.from_args()
config = Config(**config_dict)
assert config['ascii'] is True assert config['ascii'] is True
assert config['subreddit'] == 'cfb' assert config['subreddit'] == 'cfb'
assert config['link'] == 'https://reddit.com/permalink •'
assert config['log'] == 'logfile.log' assert config['log'] == 'logfile.log'
assert config['ascii'] is True assert config['ascii'] is True
assert config['persistent'] is False assert config['persistent'] is False
assert config['clear_auth'] is True assert config['clear_auth'] is True
assert config['link'] == 'https://reddit.com/permalink •'
assert config['config'] == 'configfile.cfg'
assert config['copy_config'] is True
def test_config_from_file(): def test_config_from_file():
@@ -68,15 +90,17 @@ def test_config_from_file():
'subreddit': 'cfb'} 'subreddit': 'cfb'}
with NamedTemporaryFile(suffix='.cfg') as fp: with NamedTemporaryFile(suffix='.cfg') as fp:
config = Config(config_file=fp.name)
config.from_file() fargs = Config.get_file(filename=fp.name)
config = Config(**fargs)
assert config.config == {} assert config.config == {}
rows = ['{0}={1}'.format(key, val) for key, val in args.items()] rows = ['{0}={1}'.format(key, val) for key, val in args.items()]
data = '\n'.join(['[rtv]'] + rows) data = '\n'.join(['[rtv]'] + rows)
fp.write(codecs.encode(data, 'utf-8')) fp.write(codecs.encode(data, 'utf-8'))
fp.flush() fp.flush()
config.from_file() fargs = Config.get_file(filename=fp.name)
config.update(**fargs)
assert config.config == args assert config.config == args