OAuth authentication
This commit is contained in:
@@ -7,7 +7,7 @@ import logging
|
|||||||
import requests
|
import requests
|
||||||
import praw
|
import praw
|
||||||
import praw.errors
|
import praw.errors
|
||||||
from six.moves import configparser
|
import configparser
|
||||||
|
|
||||||
from . import config
|
from . import config
|
||||||
from .exceptions import SubmissionError, SubredditError, SubscriptionError, ProgramError
|
from .exceptions import SubmissionError, SubredditError, SubscriptionError, ProgramError
|
||||||
@@ -15,6 +15,7 @@ from .curses_helpers import curses_session
|
|||||||
from .submission import SubmissionPage
|
from .submission import SubmissionPage
|
||||||
from .subreddit import SubredditPage
|
from .subreddit import SubredditPage
|
||||||
from .docs import *
|
from .docs import *
|
||||||
|
from .oauth import load_oauth_config, read_setting, write_setting, authorize
|
||||||
from .__version__ import __version__
|
from .__version__ import __version__
|
||||||
|
|
||||||
__all__ = []
|
__all__ = []
|
||||||
@@ -106,9 +107,25 @@ def main():
|
|||||||
print('Connecting...')
|
print('Connecting...')
|
||||||
reddit = praw.Reddit(user_agent=AGENT)
|
reddit = praw.Reddit(user_agent=AGENT)
|
||||||
reddit.config.decode_html_entities = False
|
reddit.config.decode_html_entities = False
|
||||||
if args.username:
|
if read_setting(key="authorization_token") is None:
|
||||||
|
print('Hello OAuth login helper!')
|
||||||
|
authorize(reddit)
|
||||||
|
else:
|
||||||
|
oauth_config = load_oauth_config()
|
||||||
|
oauth_data = {}
|
||||||
|
if oauth_config.has_section('oauth'):
|
||||||
|
oauth_data = dict(oauth_config.items('oauth'))
|
||||||
|
|
||||||
|
reddit.set_oauth_app_info(oauth_data['client_id'],
|
||||||
|
oauth_data['client_secret'],
|
||||||
|
oauth_data['redirect_uri'])
|
||||||
|
|
||||||
|
reddit.set_access_credentials(scope=set(oauth_data['scope'].split('-')),
|
||||||
|
access_token=oauth_data['authorization_token'],
|
||||||
|
refresh_token=oauth_data['refresh_token'])
|
||||||
|
"""if args.username:
|
||||||
# PRAW will prompt for password if it is None
|
# PRAW will prompt for password if it is None
|
||||||
reddit.login(args.username, args.password)
|
reddit.login(args.username, args.password)"""
|
||||||
with curses_session() as stdscr:
|
with curses_session() as stdscr:
|
||||||
if args.link:
|
if args.link:
|
||||||
page = SubmissionPage(stdscr, reddit, url=args.link)
|
page = SubmissionPage(stdscr, reddit, url=args.link)
|
||||||
@@ -133,6 +150,7 @@ def main():
|
|||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
# Ensure sockets are closed to prevent a ResourceWarning
|
# Ensure sockets are closed to prevent a ResourceWarning
|
||||||
|
print(reddit.is_oauth_session())
|
||||||
reddit.handler.http.close()
|
reddit.handler.http.close()
|
||||||
|
|
||||||
sys.exit(main())
|
sys.exit(main())
|
||||||
|
|||||||
87
rtv/oauth.py
Normal file
87
rtv/oauth.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
import configparser
|
||||||
|
import os
|
||||||
|
import webbrowser
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
__all__ = []
|
||||||
|
|
||||||
|
def get_config_file_path():
|
||||||
|
HOME = os.path.expanduser('~')
|
||||||
|
XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME', os.path.join(HOME, '.config'))
|
||||||
|
config_paths = [
|
||||||
|
os.path.join(XDG_CONFIG_HOME, 'rtv', 'rtv.cfg'),
|
||||||
|
os.path.join(HOME, '.rtv')
|
||||||
|
]
|
||||||
|
|
||||||
|
# get the first existing config file
|
||||||
|
for config_path in config_paths:
|
||||||
|
if os.path.exists(config_path):
|
||||||
|
break
|
||||||
|
|
||||||
|
return config_path
|
||||||
|
|
||||||
|
def load_oauth_config():
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config_path = get_config_file_path()
|
||||||
|
config.read(config_path)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def read_setting(key, section='oauth'):
|
||||||
|
config = load_oauth_config()
|
||||||
|
|
||||||
|
try:
|
||||||
|
setting = config[section][key]
|
||||||
|
except KeyError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return setting
|
||||||
|
|
||||||
|
def write_setting(key, value, section='oauth'):
|
||||||
|
config = load_oauth_config()
|
||||||
|
|
||||||
|
config[section][key] = value
|
||||||
|
with open(config_path, 'w') as cfg_file:
|
||||||
|
config.write(cfg_file)
|
||||||
|
|
||||||
|
def authorize(reddit):
|
||||||
|
config = load_oauth_config()
|
||||||
|
|
||||||
|
settings = {}
|
||||||
|
if config.has_section('oauth'):
|
||||||
|
settings = dict(config.items('oauth'))
|
||||||
|
|
||||||
|
scopes = ["edit", "history", "identity", "mysubreddits", "privatemessages", "read", "report", "save", "submit", "subscribe", "vote"]
|
||||||
|
|
||||||
|
reddit.set_oauth_app_info(settings['client_id'],
|
||||||
|
settings['client_secret'],
|
||||||
|
settings['redirect_uri'])
|
||||||
|
|
||||||
|
# Generate a random UUID
|
||||||
|
hex_uuid = uuid.uuid4().hex
|
||||||
|
|
||||||
|
permission_ask_page_link = reddit.get_authorize_url(str(hex_uuid), scope=scopes, refreshable=True)
|
||||||
|
input("You will now be redirected to your web browser. Press Enter to continue.")
|
||||||
|
webbrowser.open(permission_ask_page_link)
|
||||||
|
|
||||||
|
print("After allowing rtv app access, you will land on a page giving you a state and a code string. Please enter them here.")
|
||||||
|
final_state = input("State : ")
|
||||||
|
final_code = input("Code : ")
|
||||||
|
|
||||||
|
# Check if UUID matches obtained state
|
||||||
|
# (if not, authorization process is compromised, and I'm giving up)
|
||||||
|
if hex_uuid == final_state:
|
||||||
|
print("Obtained state matches UUID")
|
||||||
|
else:
|
||||||
|
print("Obtained state does not match UUID, stopping.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get access information (authorization token)
|
||||||
|
info = reddit.get_access_information(final_code)
|
||||||
|
config['oauth']['authorization_token'] = info['access_token']
|
||||||
|
config['oauth']['refresh_token'] = info['refresh_token']
|
||||||
|
config['oauth']['scope'] = '-'.join(info['scope'])
|
||||||
|
|
||||||
|
config_path = get_config_file_path()
|
||||||
|
with open(config_path, 'w') as cfg_file:
|
||||||
|
config.write(cfg_file)
|
||||||
@@ -169,7 +169,7 @@ class SubredditPage(BasePage):
|
|||||||
def open_subscriptions(self):
|
def open_subscriptions(self):
|
||||||
"Open user subscriptions page"
|
"Open user subscriptions page"
|
||||||
|
|
||||||
if not self.reddit.is_logged_in():
|
if not self.reddit.is_logged_in() and not self.reddit.is_oauth_session():
|
||||||
show_notification(self.stdscr, ['Not logged in'])
|
show_notification(self.stdscr, ['Not logged in'])
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user