diff --git a/rtv/docs.py b/rtv/docs.py index da1a273..6bdf31a 100644 --- a/rtv/docs.py +++ b/rtv/docs.py @@ -107,7 +107,7 @@ SUBMISSION_EDIT_FILE = """{content} OAUTH_ACCESS_DENIED = """\
Reddit Terminal Viewer was - denied access and will continue to operate in unauthenticated mode + denied access and will continue to operate in unauthenticated mode, you can close this window.
""" diff --git a/rtv/oauth.py b/rtv/oauth.py index 8147fb3..a09fccc 100644 --- a/rtv/oauth.py +++ b/rtv/oauth.py @@ -23,9 +23,23 @@ INDEX = os.path.join(TEMPLATES, 'index.html') class OAuthHandler(BaseHTTPRequestHandler): + # params are stored as a global because we don't have control over what + # gets passed into the handler __init__. These will be accessed by the + # OAuthHelper class. params = {'state': None, 'code': None, 'error': None} + shutdown_on_request = True def do_GET(self): + """ + Accepts GET requests to http://localhost:6500/, and stores the query + params in the global dict. If shutdown_on_request is true, stop the + server after the first successful request. + + The http request may contain the following query params: + - state : unique identifier, should match what we passed to reddit + - code : code that can be exchanged for a refresh token + - error : if provided, the OAuth error that occurred + """ parsed_path = urlparse(self.path) if parsed_path.path != '/': @@ -46,13 +60,27 @@ class OAuthHandler(BaseHTTPRequestHandler): self.wfile.write(body) - # Shutdown the server after serving the request - # http://stackoverflow.com/a/22533929 - thread = threading.Thread(target=self.server.shutdown) - thread.daemon = True - thread.start() + if self.shutdown_on_request: + # Shutdown the server after serving the request + # http://stackoverflow.com/a/22533929 + thread = threading.Thread(target=self.server.shutdown) + thread.daemon = True + thread.start() + + def log_message(self, format, *args): + """ + Redirect logging to our own handler instead of stdout + """ + _logger.debug(format, *args) def build_body(self, template_file=INDEX): + """ + Params: + template_file (text): Path to an index.html template + + Returns: + body (bytes): THe utf-8 encoded document body + """ if self.params['error'] == 'access_denied': message = docs.OAUTH_ACCESS_DENIED @@ -70,9 +98,6 @@ class OAuthHandler(BaseHTTPRequestHandler): body = codecs.encode(body, 'utf-8') return body - def log_message(self, format, *args): - _logger.debug(format, *args) - class OAuthHelper(object): @@ -84,8 +109,9 @@ class OAuthHelper(object): self.reddit = reddit self.config = config - address = ('', self.config['oauth_redirect_port']) - self.server = HTTPServer(address, OAuthHandler) + # Wait to initialize the server, we don't want to reserve the port + # unless we know that the server needs to be used. + self.server = None self.reddit.set_oauth_app_info( self.config['oauth_client_id'], @@ -112,28 +138,50 @@ class OAuthHelper(object): authorize_url = self.reddit.get_authorize_url( state, scope=self.config['oauth_scope'], refreshable=True) + if self.server is None: + address = ('', self.config['oauth_redirect_port']) + self.server = HTTPServer(address, OAuthHandler) + if self.term.display: # Open a background browser (e.g. firefox) which is non-blocking. - # Stop the iloop when the user hits the auth callback, at which - # point we continue and check the callback params. + # The server will block until it responds to its first request, + # at which point we can check the callback params. + OAuthHandler.shutdown_on_request = True with self.term.loader('Opening browser for authorization'): self.term.open_browser(authorize_url) self.server.serve_forever() if self.term.loader.exception: + # Don't need to call server.shutdown() because serve_forever() + # is wrapped in a try-finally that doees it for us. return else: # Open the terminal webbrowser in a background thread and wait # while for the user to close the process. Once the process is # closed, the iloop is stopped and we can check if the user has # hit the callback URL. + OAuthHandler.shutdown_on_request = False with self.term.loader('Redirecting to reddit', delay=0): # This load message exists to provide user feedback time.sleep(1) - io.add_callback(self._async_open_browser, authorize_url) - io.start() + + thread = threading.Thread(target=self.server.serve_forever) + thread.daemon = True + thread.start() + try: + self.term.open_browser(authorize_url) + except Exception as e: + # If an exception is raised it will be seen by the thread + # so we don't need to explicitly shutdown() the server + _logger.exception(e) + self.term.show_notification('Browser Error') + else: + _logger.debug('Calling server shutdown()') + self.server.shutdown() + finally: + thread.join() if self.params['error'] == 'access_denied': - self.term.show_notification('Declined access') + self.term.show_notification('Denied access') return elif self.params['error']: self.term.show_notification('Authentication error')