Merge branch 'oauth2'
This commit is contained in:
36
README.rst
36
README.rst
@@ -46,6 +46,13 @@ The installation will place a script in the system path
|
||||
$ rtv
|
||||
$ rtv --help
|
||||
|
||||
If you're having issues running RTV with Python 2, run RTV as module :
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ cd /path/to/rtv
|
||||
$ python2 -m rtv
|
||||
|
||||
=====
|
||||
Usage
|
||||
=====
|
||||
@@ -152,19 +159,14 @@ Config File
|
||||
-----------
|
||||
|
||||
RTV will read a configuration placed at ``~/.config/rtv/rtv.cfg`` (or ``$XDG_CONFIG_HOME``).
|
||||
Each line in the file will replace the corresponding default argument in the launch script.
|
||||
Each line in the files will replace the corresponding default argument in the launch script.
|
||||
This can be used to avoid having to re-enter login credentials every time the program is launched.
|
||||
|
||||
The OAuth section contains a boolean to trigger auto-login (defaults to False).
|
||||
When authenticated, two additional fields are written : **access_token** and **refresh_token**.
|
||||
Those are basically like username and password : they are used to authenticate you on Reddit servers.
|
||||
|
||||
Example initial config:
|
||||
|
||||
.. code-block:: ini
|
||||
**rtv.cfg**
|
||||
|
||||
[oauth]
|
||||
auto_login=False
|
||||
.. code-block:: ini
|
||||
|
||||
[rtv]
|
||||
# Log file location
|
||||
@@ -180,6 +182,24 @@ Example initial config:
|
||||
# This may be necessary for compatibility with some terminal browsers
|
||||
# ascii=True
|
||||
|
||||
-----
|
||||
OAuth
|
||||
-----
|
||||
|
||||
OAuth is an authentication standard, that replaces authentication with login and password.
|
||||
|
||||
RTV implements OAuth. It stores OAuth configuration at ``~/.config/rtv/oauth.cfg``(or ``$XDG_CONFIG_HOME``).
|
||||
**The OAuth configuration file must be writable, and is created automatically if it doesn't exist.**
|
||||
It contains a boolean to trigger auto-login (defaults to false).
|
||||
When authenticated, an additional field is written : **refresh_token**.
|
||||
This acts as a replacement to username and password : it is used to authenticate you on Reddit servers.
|
||||
|
||||
Example **oauth.cfg**:
|
||||
|
||||
.. code-block:: ini
|
||||
|
||||
[oauth]
|
||||
auto_login=false
|
||||
|
||||
=========
|
||||
Changelog
|
||||
|
||||
@@ -22,7 +22,13 @@ from tornado import ioloop
|
||||
|
||||
__all__ = []
|
||||
|
||||
def get_config_fp():
|
||||
def load_rtv_config():
|
||||
"""
|
||||
Attempt to load saved settings for things like the username and password.
|
||||
"""
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
|
||||
HOME = os.path.expanduser('~')
|
||||
XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME',
|
||||
os.path.join(HOME, '.config'))
|
||||
@@ -35,30 +41,9 @@ def get_config_fp():
|
||||
# get the first existing config file
|
||||
for config_path in config_paths:
|
||||
if os.path.exists(config_path):
|
||||
config.read(config_path)
|
||||
break
|
||||
|
||||
return config_path
|
||||
|
||||
def open_config():
|
||||
"""
|
||||
Search for a configuration file at the location ~/.rtv and attempt to load
|
||||
saved settings for things like the username and password.
|
||||
"""
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
|
||||
config_path = get_config_fp()
|
||||
config.read(config_path)
|
||||
|
||||
return config
|
||||
|
||||
def load_rtv_config():
|
||||
"""
|
||||
Attempt to load saved settings for things like the username and password.
|
||||
"""
|
||||
|
||||
config = open_config()
|
||||
|
||||
defaults = {}
|
||||
if config.has_section('rtv'):
|
||||
defaults = dict(config.items('rtv'))
|
||||
@@ -73,14 +58,26 @@ def load_oauth_config():
|
||||
Attempt to load saved OAuth settings
|
||||
"""
|
||||
|
||||
config = open_config()
|
||||
config = configparser.ConfigParser()
|
||||
|
||||
HOME = os.path.expanduser('~')
|
||||
XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME',
|
||||
os.path.join(HOME, '.config'))
|
||||
|
||||
if os.path.exists(os.path.join(XDG_CONFIG_HOME, 'rtv')):
|
||||
config_path = os.path.join(XDG_CONFIG_HOME, 'rtv', 'oauth.cfg')
|
||||
else:
|
||||
config_path = os.path.join(HOME, '.rtv-oauth')
|
||||
|
||||
config.read(config_path)
|
||||
|
||||
if config.has_section('oauth'):
|
||||
defaults = dict(config.items('oauth'))
|
||||
else:
|
||||
# Populate OAuth section
|
||||
config['oauth'] = {'auto_login': False}
|
||||
with open(get_config_fp(), 'w') as cfg:
|
||||
config.add_section('oauth')
|
||||
config.set('oauth', 'auto_login', 'false')
|
||||
with open(config_path, 'w') as cfg:
|
||||
config.write(cfg)
|
||||
defaults = dict(config.items('oauth'))
|
||||
|
||||
@@ -106,7 +103,6 @@ def command_line():
|
||||
|
||||
oauth_group = parser.add_argument_group('OAuth data (optional)', OAUTH)
|
||||
oauth_group.add_argument('--auto-login', dest='auto_login', help='OAuth auto-login setting')
|
||||
oauth_group.add_argument('--auth-token', dest='access_token', help='OAuth authorization token')
|
||||
oauth_group.add_argument('--refresh-token', dest='refresh_token', help='OAuth refresh token')
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -154,7 +150,7 @@ def main():
|
||||
reddit.config.decode_html_entities = False
|
||||
with curses_session() as stdscr:
|
||||
oauth = OAuthTool(reddit, stdscr, LoadScreen(stdscr))
|
||||
if args.auto_login == 'True': # Ew!
|
||||
if args.auto_login == 'true': # Ew!
|
||||
oauth.authorize()
|
||||
if args.link:
|
||||
page = SubmissionPage(stdscr, reddit, oauth, url=args.link)
|
||||
|
||||
@@ -18,8 +18,8 @@ given, the program will display a secure prompt to enter a password.
|
||||
"""
|
||||
|
||||
OAUTH = """\
|
||||
Authentication is now done by OAuth, since PRAW will stop supporting login with
|
||||
username and password soon.
|
||||
Authentication is now done by OAuth, since PRAW will drop
|
||||
password authentication soon.
|
||||
"""
|
||||
|
||||
CONTROLS = """
|
||||
|
||||
109
rtv/oauth.py
109
rtv/oauth.py
@@ -1,4 +1,3 @@
|
||||
from six.moves import configparser
|
||||
import curses
|
||||
import logging
|
||||
import os
|
||||
@@ -7,6 +6,7 @@ import uuid
|
||||
import webbrowser
|
||||
|
||||
import praw
|
||||
from six.moves import configparser
|
||||
|
||||
from . import config
|
||||
from .curses_helpers import show_notification, prompt_input
|
||||
@@ -16,8 +16,6 @@ from tornado import ioloop, web
|
||||
__all__ = ['token_validity', 'OAuthTool']
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
token_validity = 3540
|
||||
|
||||
oauth_state = None
|
||||
oauth_code = None
|
||||
oauth_error = None
|
||||
@@ -30,17 +28,18 @@ class HomeHandler(web.RequestHandler):
|
||||
class AuthHandler(web.RequestHandler):
|
||||
|
||||
def get(self):
|
||||
global oauth_state
|
||||
global oauth_code
|
||||
global oauth_error
|
||||
try:
|
||||
global oauth_state
|
||||
global oauth_code
|
||||
global oauth_error
|
||||
|
||||
oauth_state = self.get_argument('state', default='state_placeholder')
|
||||
oauth_code = self.get_argument('code', default='code_placeholder')
|
||||
oauth_error = self.get_argument('error', default='error_placeholder')
|
||||
oauth_state = self.get_argument('state', default='state_placeholder')
|
||||
oauth_code = self.get_argument('code', default='code_placeholder')
|
||||
oauth_error = self.get_argument('error', default='error_placeholder')
|
||||
|
||||
self.render('auth.html', state=oauth_state, code=oauth_code, error=oauth_error)
|
||||
|
||||
ioloop.IOLoop.current().stop()
|
||||
self.render('auth.html', state=oauth_state, code=oauth_code, error=oauth_error)
|
||||
finally:
|
||||
ioloop.IOLoop.current().stop()
|
||||
|
||||
class OAuthTool(object):
|
||||
|
||||
@@ -59,34 +58,28 @@ class OAuthTool(object):
|
||||
self.redirect_uri = redirect_uri or config.oauth_redirect_uri
|
||||
|
||||
self.scope = scope or config.oauth_scope.split('-')
|
||||
|
||||
self.access_info = {}
|
||||
|
||||
self.token_expiration = 0
|
||||
# Terminal web browser
|
||||
self.compact = os.environ.get('BROWSER') in ['w3m', 'links', 'elinks', 'lynx']
|
||||
|
||||
# Initialize Tornado webapp and listen on port 65000
|
||||
# Initialize Tornado webapp
|
||||
self.callback_app = web.Application([
|
||||
(r'/', HomeHandler),
|
||||
(r'/auth', AuthHandler),
|
||||
], template_path='rtv/templates')
|
||||
self.callback_app.listen(65000)
|
||||
|
||||
def get_config_fp(self):
|
||||
HOME = os.path.expanduser('~')
|
||||
XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME',
|
||||
os.path.join(HOME, '.config'))
|
||||
|
||||
config_paths = [
|
||||
os.path.join(XDG_CONFIG_HOME, 'rtv', 'rtv.cfg'),
|
||||
os.path.join(HOME, '.rtv')
|
||||
]
|
||||
if os.path.exists(os.path.join(XDG_CONFIG_HOME, 'rtv')):
|
||||
file_path = os.path.join(XDG_CONFIG_HOME, 'rtv', 'oauth.cfg')
|
||||
else:
|
||||
file_path = os.path.join(HOME, '.rtv-oauth')
|
||||
|
||||
# get the first existing config file
|
||||
for config_path in config_paths:
|
||||
if os.path.exists(config_path):
|
||||
break
|
||||
|
||||
return config_path
|
||||
return file_path
|
||||
|
||||
def open_config(self, update=False):
|
||||
if self.config_fp is None:
|
||||
@@ -100,48 +93,42 @@ class OAuthTool(object):
|
||||
with open(self.config_fp, 'w') as cfg:
|
||||
self.config.write(cfg)
|
||||
|
||||
def set_token_expiration(self):
|
||||
self.token_expiration = time.time() + token_validity
|
||||
|
||||
def token_expired(self):
|
||||
return time.time() > self.token_expiration
|
||||
|
||||
def refresh(self, force=False):
|
||||
if self.token_expired() or force:
|
||||
try:
|
||||
with self.loader(message='Refreshing token'):
|
||||
new_access_info = self.reddit.refresh_access_information(
|
||||
self.config['oauth']['refresh_token'])
|
||||
self.access_info = new_access_info
|
||||
self.reddit.set_access_credentials(scope=set(self.access_info['scope']),
|
||||
access_token=self.access_info['access_token'],
|
||||
refresh_token=self.access_info['refresh_token'])
|
||||
self.set_token_expiration()
|
||||
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken,
|
||||
praw.errors.HTTPException) as e:
|
||||
show_notification(self.stdscr, ['Invalid OAuth data'])
|
||||
else:
|
||||
self.config['oauth']['access_token'] = self.access_info['access_token']
|
||||
self.config['oauth']['refresh_token'] = self.access_info['refresh_token']
|
||||
self.save_config()
|
||||
def clear_oauth_data(self):
|
||||
self.open_config(update=True)
|
||||
if self.config.has_section('oauth') and self.config.has_option('oauth', 'refresh_token'):
|
||||
self.config.remove_option('oauth', 'refresh_token')
|
||||
self.save_config()
|
||||
|
||||
def authorize(self):
|
||||
if self.compact and not '.compact' in self.reddit.config.API_PATHS['authorize']:
|
||||
self.reddit.config.API_PATHS['authorize'] += '.compact'
|
||||
|
||||
self.reddit.set_oauth_app_info(self.client_id,
|
||||
self.client_secret,
|
||||
self.redirect_uri)
|
||||
|
||||
self.open_config(update=True)
|
||||
# If no previous OAuth data found, starting from scratch
|
||||
if 'oauth' not in self.config or 'access_token' not in self.config['oauth']:
|
||||
if not self.config.has_section('oauth') or not self.config.has_option('oauth', 'refresh_token'):
|
||||
# Start HTTP server and listen on port 65000
|
||||
self.callback_app.listen(65000)
|
||||
|
||||
# Generate a random UUID
|
||||
hex_uuid = uuid.uuid4().hex
|
||||
|
||||
permission_ask_page_link = self.reddit.get_authorize_url(str(hex_uuid),
|
||||
scope=self.scope, refreshable=True)
|
||||
|
||||
with self.loader(message='Waiting for authorization'):
|
||||
webbrowser.open(permission_ask_page_link)
|
||||
if self.compact:
|
||||
show_notification(self.stdscr, ['Opening ' + os.environ.get('BROWSER')])
|
||||
curses.endwin()
|
||||
webbrowser.open_new_tab(permission_ask_page_link)
|
||||
ioloop.IOLoop.current().start()
|
||||
curses.doupdate()
|
||||
else:
|
||||
with self.loader(message='Waiting for authorization'):
|
||||
webbrowser.open(permission_ask_page_link)
|
||||
ioloop.IOLoop.current().start()
|
||||
|
||||
global oauth_state
|
||||
global oauth_code
|
||||
@@ -169,21 +156,15 @@ class OAuthTool(object):
|
||||
with self.loader(message='Logging in'):
|
||||
# Get access information (tokens and scopes)
|
||||
self.access_info = self.reddit.get_access_information(self.final_code)
|
||||
|
||||
self.reddit.set_access_credentials(
|
||||
scope=set(self.access_info['scope']),
|
||||
access_token=self.access_info['access_token'],
|
||||
refresh_token=self.access_info['refresh_token'])
|
||||
self.set_token_expiration()
|
||||
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e:
|
||||
show_notification(self.stdscr, ['Invalid OAuth data'])
|
||||
else:
|
||||
if 'oauth' not in self.config:
|
||||
self.config['oauth'] = {}
|
||||
if not self.config.has_section('oauth'):
|
||||
self.config.add_section('oauth')
|
||||
|
||||
self.config['oauth']['access_token'] = self.access_info['access_token']
|
||||
self.config['oauth']['refresh_token'] = self.access_info['refresh_token']
|
||||
self.config.set('oauth', 'refresh_token', self.access_info['refresh_token'])
|
||||
self.save_config()
|
||||
# Otherwise, fetch new access token
|
||||
else:
|
||||
self.refresh(force=True)
|
||||
with self.loader(message='Logging in'):
|
||||
self.reddit.refresh_access_information(self.config.get('oauth', 'refresh_token'))
|
||||
|
||||
16
rtv/page.py
16
rtv/page.py
@@ -314,9 +314,6 @@ class BasePage(object):
|
||||
|
||||
@BaseController.register('a')
|
||||
def upvote(self):
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
try:
|
||||
if 'likes' not in data:
|
||||
@@ -332,9 +329,6 @@ class BasePage(object):
|
||||
|
||||
@BaseController.register('z')
|
||||
def downvote(self):
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
try:
|
||||
if 'likes' not in data:
|
||||
@@ -358,6 +352,7 @@ class BasePage(object):
|
||||
|
||||
if self.reddit.is_oauth_session():
|
||||
self.reddit.clear_authentication()
|
||||
self.oauth.clear_oauth_data()
|
||||
return
|
||||
|
||||
self.oauth.authorize()
|
||||
@@ -372,9 +367,6 @@ class BasePage(object):
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
if data.get('author') != self.reddit.user.name:
|
||||
curses.flash()
|
||||
@@ -403,9 +395,6 @@ class BasePage(object):
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
if data.get('author') != self.reddit.user.name:
|
||||
curses.flash()
|
||||
@@ -440,9 +429,6 @@ class BasePage(object):
|
||||
Checks the inbox for unread messages and displays a notification.
|
||||
"""
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
inbox = len(list(self.reddit.get_unread(limit=1)))
|
||||
try:
|
||||
if inbox > 0:
|
||||
|
||||
@@ -93,9 +93,6 @@ class SubmissionPage(BasePage):
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
if data['type'] == 'Submission':
|
||||
content = data['text']
|
||||
@@ -131,9 +128,6 @@ class SubmissionPage(BasePage):
|
||||
def delete_comment(self):
|
||||
"Delete a comment as long as it is not the current submission"
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
if self.nav.absolute_index != -1:
|
||||
self.delete()
|
||||
else:
|
||||
|
||||
@@ -8,7 +8,7 @@ import requests
|
||||
from .exceptions import SubredditError, AccountError
|
||||
from .page import BasePage, Navigator, BaseController
|
||||
from .submission import SubmissionPage
|
||||
from .subscriptions import SubscriptionPage
|
||||
from .subscription import SubscriptionPage
|
||||
from .content import SubredditContent
|
||||
from .helpers import open_browser, open_editor, strip_subreddit_url
|
||||
from .docs import SUBMISSION_FILE
|
||||
@@ -54,9 +54,6 @@ class SubredditPage(BasePage):
|
||||
def refresh_content(self, name=None, order=None):
|
||||
"Re-download all submissions and reset the page index"
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
name = name or self.content.name
|
||||
order = order or self.content.order
|
||||
|
||||
@@ -136,9 +133,6 @@ class SubredditPage(BasePage):
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
# Strips the subreddit to just the name
|
||||
# Make sure it is a valid subreddit for submission
|
||||
subreddit = self.reddit.get_subreddit(self.content.name)
|
||||
@@ -180,9 +174,6 @@ class SubredditPage(BasePage):
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
# Open subscriptions page
|
||||
page = SubscriptionPage(self.stdscr, self.reddit, self.oauth)
|
||||
page.loop()
|
||||
|
||||
@@ -38,9 +38,6 @@ class SubscriptionPage(BasePage):
|
||||
def refresh_content(self):
|
||||
"Re-download all subscriptions and reset the page index"
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
self.content = SubscriptionContent.from_user(self.reddit, self.loader)
|
||||
self.nav = Navigator(self.content.get)
|
||||
|
||||
Reference in New Issue
Block a user