Config now loads default values from a file alongside the source.
This commit is contained in:
154
rtv/config.py
154
rtv/config.py
@@ -3,20 +3,23 @@ from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import codecs
|
||||
import shutil
|
||||
import argparse
|
||||
from functools import partial
|
||||
|
||||
from six.moves import configparser
|
||||
from six.moves import configparser, input
|
||||
|
||||
from . import docs, __version__
|
||||
|
||||
|
||||
HOME = os.path.expanduser('~')
|
||||
PACKAGE = os.path.dirname(__file__)
|
||||
HOME = os.path.expanduser('~')
|
||||
TEMPLATE = os.path.join(PACKAGE, 'templates')
|
||||
DEFAULT_CONFIG = os.path.join(PACKAGE, 'rtv.cfg')
|
||||
XDG_HOME = os.getenv('XDG_CONFIG_HOME', os.path.join(HOME, '.config'))
|
||||
CONFIG = os.path.join(XDG_HOME, 'rtv', 'rtv.cfg')
|
||||
TOKEN = os.path.join(XDG_HOME, 'rtv', 'refresh-token')
|
||||
HISTORY = os.path.join(XDG_HOME, 'rtv', 'history.log')
|
||||
TEMPLATE = os.path.join(PACKAGE, 'templates')
|
||||
|
||||
|
||||
def build_parser():
|
||||
@@ -33,22 +36,49 @@ def build_parser():
|
||||
parser.add_argument(
|
||||
'-l', dest='link',
|
||||
help='full URL of a submission that will be opened on start')
|
||||
parser.add_argument(
|
||||
'--log', metavar='FILE', action='store',
|
||||
help='log HTTP requests')
|
||||
parser.add_argument(
|
||||
'--config', metavar='FILE', action='store',
|
||||
help='Load configuration settings')
|
||||
parser.add_argument(
|
||||
'--ascii', action='store_const', const=True,
|
||||
help='enable ascii-only mode')
|
||||
parser.add_argument(
|
||||
'--log', metavar='FILE', action='store',
|
||||
help='log HTTP requests to a file')
|
||||
parser.add_argument(
|
||||
'--non-persistent', dest='persistent', action='store_const',
|
||||
const=False,
|
||||
help='Forget all authenticated users when the program exits')
|
||||
help='Forget the authenticated user when the program exits')
|
||||
parser.add_argument(
|
||||
'--clear-auth', dest='clear_auth', action='store_const', const=True,
|
||||
help='Remove any saved OAuth tokens before starting')
|
||||
help='Remove any saved user data before launching')
|
||||
parser.add_argument(
|
||||
'--copy-config', dest='copy_config', action='store_const', const=True,
|
||||
help='Copy the default configuration to {HOME}/.config/rtv/rtv.cfg')
|
||||
return parser
|
||||
|
||||
|
||||
def copy_default_config(filename=CONFIG):
|
||||
"""
|
||||
Copy the default configuration file to the user's {HOME}/.config/rtv
|
||||
"""
|
||||
|
||||
if os.path.exists(filename):
|
||||
try:
|
||||
ch = input('File %s already exists, overwrite? y/[n]):' % filename)
|
||||
if ch not in ('Y', 'y'):
|
||||
return
|
||||
except KeyboardInterrupt:
|
||||
return
|
||||
|
||||
filepath = os.path.dirname(filename)
|
||||
if not os.path.exists(filepath):
|
||||
os.makedirs(filepath)
|
||||
|
||||
print('Copying default settings to %s' % filename)
|
||||
shutil.copy(DEFAULT_CONFIG, filename)
|
||||
|
||||
|
||||
class OrderedSet(object):
|
||||
"""
|
||||
A simple implementation of an ordered set. A set is used to check
|
||||
@@ -76,36 +106,12 @@ class OrderedSet(object):
|
||||
|
||||
class Config(object):
|
||||
|
||||
DEFAULT = {
|
||||
'ascii': False,
|
||||
'persistent': True,
|
||||
'clear_auth': False,
|
||||
'log': None,
|
||||
'link': None,
|
||||
'subreddit': 'front',
|
||||
'history_size': 200,
|
||||
# https://github.com/reddit/reddit/wiki/OAuth2
|
||||
# Client ID is of type "installed app" and the secret should be empty
|
||||
'oauth_client_id': 'E2oEtRQfdfAfNQ',
|
||||
'oauth_client_secret': 'praw_gapfill',
|
||||
'oauth_redirect_uri': 'http://127.0.0.1:65000/',
|
||||
'oauth_redirect_port': 65000,
|
||||
'oauth_scope': [
|
||||
'edit', 'history', 'identity', 'mysubreddits', 'privatemessages',
|
||||
'read', 'report', 'save', 'submit', 'subscribe', 'vote'],
|
||||
'template_path': TEMPLATE,
|
||||
}
|
||||
def __init__(self, history_file=HISTORY, token_file=TOKEN, **kwargs):
|
||||
|
||||
def __init__(self,
|
||||
config_file=CONFIG,
|
||||
history_file=HISTORY,
|
||||
token_file=TOKEN,
|
||||
**kwargs):
|
||||
|
||||
self.config_file = config_file
|
||||
self.history_file = history_file
|
||||
self.token_file = token_file
|
||||
self.config = kwargs
|
||||
self.default = self.get_file(DEFAULT_CONFIG)
|
||||
|
||||
# `refresh_token` and `history` are saved/loaded at separate locations,
|
||||
# so they are treated differently from the rest of the config options.
|
||||
@@ -113,7 +119,10 @@ class Config(object):
|
||||
self.history = OrderedSet()
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.config.get(item, self.DEFAULT.get(item))
|
||||
if item in self.config:
|
||||
return self.config[item]
|
||||
else:
|
||||
return self.default.get(item, None)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.config[key] = value
|
||||
@@ -124,33 +133,6 @@ class Config(object):
|
||||
def update(self, **kwargs):
|
||||
self.config.update(kwargs)
|
||||
|
||||
def from_args(self):
|
||||
parser = build_parser()
|
||||
args = vars(parser.parse_args())
|
||||
# Filter out argument values that weren't supplied
|
||||
args = {key: val for key, val in args.items() if val is not None}
|
||||
self.update(**args)
|
||||
|
||||
def from_file(self):
|
||||
config = configparser.ConfigParser()
|
||||
if os.path.exists(self.config_file):
|
||||
with codecs.open(self.config_file, encoding='utf-8') as fp:
|
||||
config.readfp(fp)
|
||||
|
||||
config_dict = {}
|
||||
if config.has_section('rtv'):
|
||||
config_dict = dict(config.items('rtv'))
|
||||
|
||||
# Convert 'true'/'false' to boolean True/False
|
||||
if 'ascii' in config_dict:
|
||||
config_dict['ascii'] = config.getboolean('rtv', 'ascii')
|
||||
if 'clear_auth' in config_dict:
|
||||
config_dict['clear_auth'] = config.getboolean('rtv', 'clear_auth')
|
||||
if 'persistent' in config_dict:
|
||||
config_dict['persistent'] = config.getboolean('rtv', 'persistent')
|
||||
|
||||
self.update(**config_dict)
|
||||
|
||||
def load_refresh_token(self):
|
||||
if os.path.exists(self.token_file):
|
||||
with open(self.token_file) as fp:
|
||||
@@ -185,6 +167,54 @@ class Config(object):
|
||||
os.remove(self.history_file)
|
||||
self.history = OrderedSet()
|
||||
|
||||
@staticmethod
|
||||
def get_args():
|
||||
"""
|
||||
Load settings from the command line.
|
||||
"""
|
||||
|
||||
parser = build_parser()
|
||||
args = vars(parser.parse_args())
|
||||
# Filter out argument values that weren't supplied
|
||||
return {key: val for key, val in args.items() if val is not None}
|
||||
|
||||
@classmethod
|
||||
def get_file(cls, filename=None):
|
||||
"""
|
||||
Load settings from an rtv configuration file.
|
||||
"""
|
||||
|
||||
if filename is None:
|
||||
filename = CONFIG
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
if os.path.exists(filename):
|
||||
with codecs.open(filename, encoding='utf-8') as fp:
|
||||
config.readfp(fp)
|
||||
|
||||
return cls._parse_rtv_file(config)
|
||||
|
||||
@staticmethod
|
||||
def _parse_rtv_file(config):
|
||||
|
||||
out = {}
|
||||
if config.has_section('rtv'):
|
||||
out = dict(config.items('rtv'))
|
||||
|
||||
params = {
|
||||
'ascii': partial(config.getboolean, 'rtv'),
|
||||
'clear_auth': partial(config.getboolean, 'rtv'),
|
||||
'persistent': partial(config.getboolean, 'rtv'),
|
||||
'history_size': partial(config.getint, 'rtv'),
|
||||
'oauth_redirect_port': partial(config.getint, 'rtv'),
|
||||
'oauth_scope': lambda x: out[x].split(',')
|
||||
}
|
||||
|
||||
for key, func in params.items():
|
||||
if key in out:
|
||||
out[key] = func(key)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _ensure_filepath(filename):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user