diff --git a/tests/conftest.py b/tests/conftest.py index f8c2d6b..d78c2de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,14 +4,16 @@ from __future__ import unicode_literals import os import curses import logging +import threading from functools import partial import praw import pytest from vcr import VCR from six.moves.urllib.parse import urlparse, parse_qs +from six.moves.BaseHTTPServer import HTTPServer -from rtv.oauth import OAuthHelper +from rtv.oauth import OAuthHelper, OAuthHandler from rtv.config import Config from rtv.terminal import Terminal from rtv.subreddit import SubredditPage @@ -196,6 +198,21 @@ def oauth(reddit, terminal, config): return OAuthHelper(reddit, terminal, config) +@pytest.yield_fixture() +def oauth_server(): + # Start the OAuth server on a random port in the background + server = HTTPServer(('', 0), OAuthHandler) + server.url = 'http://{0}:{1}/'.format(*server.server_address) + thread = threading.Thread(target=server.serve_forever) + thread.start() + try: + yield server + finally: + server.shutdown() + thread.join() + server.server_close() + + @pytest.fixture() def submission_page(reddit, terminal, config, oauth): submission = 'https://www.reddit.com/r/Python/comments/2xmo63' diff --git a/tests/test_content.py b/tests/test_content.py index e7c6452..bcc0d5f 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -6,13 +6,16 @@ from itertools import islice import six import praw -import mock import pytest from rtv.content import ( Content, SubmissionContent, SubredditContent, SubscriptionContent) from rtv import exceptions +try: + from unittest import mock +except ImportError: + import mock # Test entering a bunch of text into the prompt # (text, parsed subreddit, parsed order) diff --git a/tests/test_oauth.py b/tests/test_oauth.py index 25332dc..f6f4f7c 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -1,12 +1,11 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals -from tornado.web import Application -from tornado.testing import AsyncHTTPTestCase +import requests from praw.errors import OAuthException from rtv.oauth import OAuthHelper, OAuthHandler -from rtv.config import TEMPLATE + try: from unittest import mock @@ -14,38 +13,48 @@ except ImportError: import mock -class TestAuthHandler(AsyncHTTPTestCase): +def test_oauth_handler_not_found(oauth_server): - def get_app(self): - self.params = {} - handler = [('/', OAuthHandler, {'params': self.params})] - return Application(handler, template_path=TEMPLATE) + url = oauth_server.url + 'favicon.ico' + resp = requests.get(url) + assert resp.status_code == 404 - def test_no_callback(self): - resp = self.fetch('/') - assert resp.code == 200 - assert self.params['error'] is None - assert 'Wait...' in resp.body.decode() - def test_access_denied(self): - resp = self.fetch('/?error=access_denied') - assert resp.code == 200 - assert self.params['error'] == 'access_denied' - assert 'was denied access' in resp.body.decode() +def test_oauth_handler_no_callback(oauth_server): - def test_error(self): - resp = self.fetch('/?error=fake') - assert resp.code == 200 - assert self.params['error'] == 'fake' - assert 'fake' in resp.body.decode() + resp = requests.get(oauth_server.url) + assert resp.status_code == 200 + assert 'Wait...' in resp.text + assert OAuthHandler.params['error'] is None - def test_success(self): - resp = self.fetch('/?state=fake_state&code=fake_code') - assert resp.code == 200 - assert self.params['error'] is None - assert self.params['state'] == 'fake_state' - assert self.params['code'] == 'fake_code' - assert 'Access Granted' in resp.body.decode() + +def test_oauth_handler_access_denied(oauth_server): + + url = oauth_server.url + '?error=access_denied' + resp = requests.get(url) + assert resp.status_code == 200 + assert OAuthHandler.params['error'] == 'access_denied' + assert 'denied access' in resp.text + + +def test_oauth_handler_error(oauth_server): + + url = oauth_server.url + '?error=fake' + resp = requests.get(url) + assert resp.status_code == 200 + assert OAuthHandler.params['error'] == 'fake' + assert 'fake' in resp.text + + +def test_oauth_handler_success(oauth_server): + + url = oauth_server.url + '?state=fake_state&code=fake_code' + resp = requests.get(url) + assert resp.status_code == 200 + assert OAuthHandler.params['error'] is None + assert OAuthHandler.params['state'] == 'fake_state' + assert OAuthHandler.params['code'] == 'fake_code' + assert 'Access Granted' in resp.text def test_oauth_terminal_non_mobile_authorize(reddit, terminal, config): @@ -66,11 +75,11 @@ def test_oauth_terminal_mobile_authorize(reddit, terminal, config): assert '.compact' in oauth.reddit.config.API_PATHS['authorize'] -def test_oauth_authorize_with_refresh_token(oauth, stdscr, refresh_token): +def test_oauth_authorize_with_refresh_token(oauth, refresh_token): oauth.config.refresh_token = refresh_token oauth.authorize() - assert oauth.http_server is None + assert oauth.server is None # We should be able to handle an oauth failure with mock.patch.object(oauth.reddit, 'refresh_access_information'): @@ -78,7 +87,15 @@ def test_oauth_authorize_with_refresh_token(oauth, stdscr, refresh_token): oauth.reddit.refresh_access_information.side_effect = exception oauth.authorize() assert isinstance(oauth.term.loader.exception, OAuthException) - assert oauth.http_server is None + assert oauth.server is None + + +def test_oauth_clear_data(oauth): + oauth.config.refresh_token = 'secrettoken' + oauth.reddit.refresh_token = 'secrettoken' + oauth.clear_oauth_data() + assert oauth.config.refresh_token is None + assert oauth.reddit.refresh_token is None def test_oauth_authorize(oauth, reddit, stdscr, refresh_token): @@ -87,34 +104,36 @@ def test_oauth_authorize(oauth, reddit, stdscr, refresh_token): # function in the destination oauth module and not the helpers module with mock.patch('uuid.UUID.hex', new_callable=mock.PropertyMock) as uuid, \ mock.patch('rtv.terminal.Terminal.open_browser') as open_browser, \ - mock.patch('rtv.oauth.ioloop') as ioloop, \ - mock.patch('rtv.oauth.httpserver'), \ + mock.patch('rtv.oauth.HTTPServer') as http_server, \ mock.patch.object(oauth.reddit, 'user'), \ mock.patch('time.sleep'): - io = ioloop.IOLoop.current.return_value # Valid authorization oauth.term._display = False params = {'state': 'uniqueid', 'code': 'secretcode', 'error': None} uuid.return_value = params['state'] - io.start.side_effect = lambda *_: oauth.params.update(**params) + + def serve_forever(): + oauth.params.update(**params) + http_server.return_value.serve_forever.side_effect = serve_forever oauth.authorize() - assert not open_browser.called + assert open_browser.called oauth.reddit.get_access_information.assert_called_with( reddit, params['code']) assert oauth.config.refresh_token is not None assert oauth.config.save_refresh_token.called + stdscr.reset_mock() oauth.reddit.get_access_information.reset_mock() oauth.config.save_refresh_token.reset_mock() - oauth.http_server = None + oauth.server = None # The next authorization should skip the oauth process oauth.config.refresh_token = refresh_token oauth.authorize() assert oauth.reddit.user is not None - assert oauth.http_server is None + assert oauth.server is None stdscr.reset_mock() # Invalid state returned @@ -129,7 +148,6 @@ def test_oauth_authorize(oauth, reddit, stdscr, refresh_token): oauth.term._display = True params = {'state': 'uniqueid', 'code': 'secretcode', 'error': None} uuid.return_value = params['state'] - io.start.side_effect = lambda *_: oauth.params.update(**params) oauth.authorize() assert open_browser.called @@ -137,11 +155,12 @@ def test_oauth_authorize(oauth, reddit, stdscr, refresh_token): reddit, params['code']) assert oauth.config.refresh_token is not None assert oauth.config.save_refresh_token.called + stdscr.reset_mock() oauth.reddit.get_access_information.reset_mock() oauth.config.refresh_token = None oauth.config.save_refresh_token.reset_mock() - oauth.http_server = None + oauth.server = None # Exceptions when logging in are handled correctly with mock.patch.object(oauth.reddit, 'get_access_information'): @@ -149,13 +168,4 @@ def test_oauth_authorize(oauth, reddit, stdscr, refresh_token): oauth.reddit.get_access_information.side_effect = exception oauth.authorize() assert isinstance(oauth.term.loader.exception, OAuthException) - assert not oauth.config.save_refresh_token.called - - -def test_oauth_clear_data(oauth): - - oauth.config.refresh_token = 'secrettoken' - oauth.reddit.refresh_token = 'secrettoken' - oauth.clear_oauth_data() - assert oauth.config.refresh_token is None - assert oauth.reddit.refresh_token is None \ No newline at end of file + assert not oauth.config.save_refresh_token.called \ No newline at end of file diff --git a/tests/test_terminal.py b/tests/test_terminal.py index 86cfff5..0d5701a 100644 --- a/tests/test_terminal.py +++ b/tests/test_terminal.py @@ -389,6 +389,7 @@ def test_open_link_subprocess(terminal): with mock.patch('time.sleep'), \ mock.patch('os.system'), \ + mock.patch('subprocess.Popen') as Popen, \ mock.patch('six.moves.input') as six_input, \ mock.patch.object(terminal, 'get_mailcap_entry'): @@ -398,6 +399,9 @@ def test_open_link_subprocess(terminal): six_input.reset_mock() os.system.reset_mock() terminal.stdscr.subwin.addstr.reset_mock() + Popen.return_value.communicate.return_value = '', 'stderr message' + Popen.return_value.poll.return_value = 0 + Popen.return_value.wait.return_value = 0 def get_error(): # Check if an error message was printed to the terminal @@ -415,6 +419,8 @@ def test_open_link_subprocess(terminal): # Non-blocking failure reset_mock() + Popen.return_value.poll.return_value = 127 + Popen.return_value.wait.return_value = 127 entry = ('fake .', 'fake %s') terminal.get_mailcap_entry.return_value = entry terminal.open_link(url) @@ -431,6 +437,8 @@ def test_open_link_subprocess(terminal): # needsterminal failure reset_mock() + Popen.return_value.poll.return_value = 127 + Popen.return_value.wait.return_value = 127 entry = ('fake .', 'fake %s; needsterminal') terminal.get_mailcap_entry.return_value = entry terminal.open_link(url) @@ -447,6 +455,8 @@ def test_open_link_subprocess(terminal): # copiousoutput failure reset_mock() + Popen.return_value.poll.return_value = 127 + Popen.return_value.wait.return_value = 127 entry = ('fake .', 'fake %s; needsterminal; copiousoutput') terminal.get_mailcap_entry.return_value = entry terminal.open_link(url)