Updating tests.

This commit is contained in:
Michael Lazar
2016-08-09 18:26:48 -07:00
parent d4cab22ffe
commit c096d7014c
4 changed files with 94 additions and 54 deletions

View File

@@ -4,14 +4,16 @@ from __future__ import unicode_literals
import os import os
import curses import curses
import logging import logging
import threading
from functools import partial from functools import partial
import praw import praw
import pytest import pytest
from vcr import VCR from vcr import VCR
from six.moves.urllib.parse import urlparse, parse_qs from six.moves.urllib.parse import urlparse, parse_qs
from six.moves.BaseHTTPServer import HTTPServer
from rtv.oauth import OAuthHelper from rtv.oauth import OAuthHelper, OAuthHandler
from rtv.config import Config from rtv.config import Config
from rtv.terminal import Terminal from rtv.terminal import Terminal
from rtv.subreddit import SubredditPage from rtv.subreddit import SubredditPage
@@ -196,6 +198,21 @@ def oauth(reddit, terminal, config):
return OAuthHelper(reddit, terminal, config) return OAuthHelper(reddit, terminal, config)
@pytest.yield_fixture()
def oauth_server():
# Start the OAuth server on a random port in the background
server = HTTPServer(('', 0), OAuthHandler)
server.url = 'http://{0}:{1}/'.format(*server.server_address)
thread = threading.Thread(target=server.serve_forever)
thread.start()
try:
yield server
finally:
server.shutdown()
thread.join()
server.server_close()
@pytest.fixture() @pytest.fixture()
def submission_page(reddit, terminal, config, oauth): def submission_page(reddit, terminal, config, oauth):
submission = 'https://www.reddit.com/r/Python/comments/2xmo63' submission = 'https://www.reddit.com/r/Python/comments/2xmo63'

View File

@@ -6,13 +6,16 @@ from itertools import islice
import six import six
import praw import praw
import mock
import pytest import pytest
from rtv.content import ( from rtv.content import (
Content, SubmissionContent, SubredditContent, SubscriptionContent) Content, SubmissionContent, SubredditContent, SubscriptionContent)
from rtv import exceptions from rtv import exceptions
try:
from unittest import mock
except ImportError:
import mock
# Test entering a bunch of text into the prompt # Test entering a bunch of text into the prompt
# (text, parsed subreddit, parsed order) # (text, parsed subreddit, parsed order)

View File

@@ -1,12 +1,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import unicode_literals from __future__ import unicode_literals
from tornado.web import Application import requests
from tornado.testing import AsyncHTTPTestCase
from praw.errors import OAuthException from praw.errors import OAuthException
from rtv.oauth import OAuthHelper, OAuthHandler from rtv.oauth import OAuthHelper, OAuthHandler
from rtv.config import TEMPLATE
try: try:
from unittest import mock from unittest import mock
@@ -14,38 +13,48 @@ except ImportError:
import mock import mock
class TestAuthHandler(AsyncHTTPTestCase): def test_oauth_handler_not_found(oauth_server):
def get_app(self): url = oauth_server.url + 'favicon.ico'
self.params = {} resp = requests.get(url)
handler = [('/', OAuthHandler, {'params': self.params})] assert resp.status_code == 404
return Application(handler, template_path=TEMPLATE)
def test_no_callback(self):
resp = self.fetch('/')
assert resp.code == 200
assert self.params['error'] is None
assert 'Wait...' in resp.body.decode()
def test_access_denied(self): def test_oauth_handler_no_callback(oauth_server):
resp = self.fetch('/?error=access_denied')
assert resp.code == 200
assert self.params['error'] == 'access_denied'
assert 'was denied access' in resp.body.decode()
def test_error(self): resp = requests.get(oauth_server.url)
resp = self.fetch('/?error=fake') assert resp.status_code == 200
assert resp.code == 200 assert 'Wait...' in resp.text
assert self.params['error'] == 'fake' assert OAuthHandler.params['error'] is None
assert 'fake' in resp.body.decode()
def test_success(self):
resp = self.fetch('/?state=fake_state&code=fake_code') def test_oauth_handler_access_denied(oauth_server):
assert resp.code == 200
assert self.params['error'] is None url = oauth_server.url + '?error=access_denied'
assert self.params['state'] == 'fake_state' resp = requests.get(url)
assert self.params['code'] == 'fake_code' assert resp.status_code == 200
assert 'Access Granted' in resp.body.decode() assert OAuthHandler.params['error'] == 'access_denied'
assert 'denied access' in resp.text
def test_oauth_handler_error(oauth_server):
url = oauth_server.url + '?error=fake'
resp = requests.get(url)
assert resp.status_code == 200
assert OAuthHandler.params['error'] == 'fake'
assert 'fake' in resp.text
def test_oauth_handler_success(oauth_server):
url = oauth_server.url + '?state=fake_state&code=fake_code'
resp = requests.get(url)
assert resp.status_code == 200
assert OAuthHandler.params['error'] is None
assert OAuthHandler.params['state'] == 'fake_state'
assert OAuthHandler.params['code'] == 'fake_code'
assert 'Access Granted' in resp.text
def test_oauth_terminal_non_mobile_authorize(reddit, terminal, config): def test_oauth_terminal_non_mobile_authorize(reddit, terminal, config):
@@ -66,11 +75,11 @@ def test_oauth_terminal_mobile_authorize(reddit, terminal, config):
assert '.compact' in oauth.reddit.config.API_PATHS['authorize'] assert '.compact' in oauth.reddit.config.API_PATHS['authorize']
def test_oauth_authorize_with_refresh_token(oauth, stdscr, refresh_token): def test_oauth_authorize_with_refresh_token(oauth, refresh_token):
oauth.config.refresh_token = refresh_token oauth.config.refresh_token = refresh_token
oauth.authorize() oauth.authorize()
assert oauth.http_server is None assert oauth.server is None
# We should be able to handle an oauth failure # We should be able to handle an oauth failure
with mock.patch.object(oauth.reddit, 'refresh_access_information'): with mock.patch.object(oauth.reddit, 'refresh_access_information'):
@@ -78,7 +87,15 @@ def test_oauth_authorize_with_refresh_token(oauth, stdscr, refresh_token):
oauth.reddit.refresh_access_information.side_effect = exception oauth.reddit.refresh_access_information.side_effect = exception
oauth.authorize() oauth.authorize()
assert isinstance(oauth.term.loader.exception, OAuthException) assert isinstance(oauth.term.loader.exception, OAuthException)
assert oauth.http_server is None assert oauth.server is None
def test_oauth_clear_data(oauth):
oauth.config.refresh_token = 'secrettoken'
oauth.reddit.refresh_token = 'secrettoken'
oauth.clear_oauth_data()
assert oauth.config.refresh_token is None
assert oauth.reddit.refresh_token is None
def test_oauth_authorize(oauth, reddit, stdscr, refresh_token): def test_oauth_authorize(oauth, reddit, stdscr, refresh_token):
@@ -87,34 +104,36 @@ def test_oauth_authorize(oauth, reddit, stdscr, refresh_token):
# function in the destination oauth module and not the helpers module # function in the destination oauth module and not the helpers module
with mock.patch('uuid.UUID.hex', new_callable=mock.PropertyMock) as uuid, \ with mock.patch('uuid.UUID.hex', new_callable=mock.PropertyMock) as uuid, \
mock.patch('rtv.terminal.Terminal.open_browser') as open_browser, \ mock.patch('rtv.terminal.Terminal.open_browser') as open_browser, \
mock.patch('rtv.oauth.ioloop') as ioloop, \ mock.patch('rtv.oauth.HTTPServer') as http_server, \
mock.patch('rtv.oauth.httpserver'), \
mock.patch.object(oauth.reddit, 'user'), \ mock.patch.object(oauth.reddit, 'user'), \
mock.patch('time.sleep'): mock.patch('time.sleep'):
io = ioloop.IOLoop.current.return_value
# Valid authorization # Valid authorization
oauth.term._display = False oauth.term._display = False
params = {'state': 'uniqueid', 'code': 'secretcode', 'error': None} params = {'state': 'uniqueid', 'code': 'secretcode', 'error': None}
uuid.return_value = params['state'] uuid.return_value = params['state']
io.start.side_effect = lambda *_: oauth.params.update(**params)
def serve_forever():
oauth.params.update(**params)
http_server.return_value.serve_forever.side_effect = serve_forever
oauth.authorize() oauth.authorize()
assert not open_browser.called assert open_browser.called
oauth.reddit.get_access_information.assert_called_with( oauth.reddit.get_access_information.assert_called_with(
reddit, params['code']) reddit, params['code'])
assert oauth.config.refresh_token is not None assert oauth.config.refresh_token is not None
assert oauth.config.save_refresh_token.called assert oauth.config.save_refresh_token.called
stdscr.reset_mock() stdscr.reset_mock()
oauth.reddit.get_access_information.reset_mock() oauth.reddit.get_access_information.reset_mock()
oauth.config.save_refresh_token.reset_mock() oauth.config.save_refresh_token.reset_mock()
oauth.http_server = None oauth.server = None
# The next authorization should skip the oauth process # The next authorization should skip the oauth process
oauth.config.refresh_token = refresh_token oauth.config.refresh_token = refresh_token
oauth.authorize() oauth.authorize()
assert oauth.reddit.user is not None assert oauth.reddit.user is not None
assert oauth.http_server is None assert oauth.server is None
stdscr.reset_mock() stdscr.reset_mock()
# Invalid state returned # Invalid state returned
@@ -129,7 +148,6 @@ def test_oauth_authorize(oauth, reddit, stdscr, refresh_token):
oauth.term._display = True oauth.term._display = True
params = {'state': 'uniqueid', 'code': 'secretcode', 'error': None} params = {'state': 'uniqueid', 'code': 'secretcode', 'error': None}
uuid.return_value = params['state'] uuid.return_value = params['state']
io.start.side_effect = lambda *_: oauth.params.update(**params)
oauth.authorize() oauth.authorize()
assert open_browser.called assert open_browser.called
@@ -137,11 +155,12 @@ def test_oauth_authorize(oauth, reddit, stdscr, refresh_token):
reddit, params['code']) reddit, params['code'])
assert oauth.config.refresh_token is not None assert oauth.config.refresh_token is not None
assert oauth.config.save_refresh_token.called assert oauth.config.save_refresh_token.called
stdscr.reset_mock() stdscr.reset_mock()
oauth.reddit.get_access_information.reset_mock() oauth.reddit.get_access_information.reset_mock()
oauth.config.refresh_token = None oauth.config.refresh_token = None
oauth.config.save_refresh_token.reset_mock() oauth.config.save_refresh_token.reset_mock()
oauth.http_server = None oauth.server = None
# Exceptions when logging in are handled correctly # Exceptions when logging in are handled correctly
with mock.patch.object(oauth.reddit, 'get_access_information'): with mock.patch.object(oauth.reddit, 'get_access_information'):
@@ -149,13 +168,4 @@ def test_oauth_authorize(oauth, reddit, stdscr, refresh_token):
oauth.reddit.get_access_information.side_effect = exception oauth.reddit.get_access_information.side_effect = exception
oauth.authorize() oauth.authorize()
assert isinstance(oauth.term.loader.exception, OAuthException) assert isinstance(oauth.term.loader.exception, OAuthException)
assert not oauth.config.save_refresh_token.called assert not oauth.config.save_refresh_token.called
def test_oauth_clear_data(oauth):
oauth.config.refresh_token = 'secrettoken'
oauth.reddit.refresh_token = 'secrettoken'
oauth.clear_oauth_data()
assert oauth.config.refresh_token is None
assert oauth.reddit.refresh_token is None

View File

@@ -389,6 +389,7 @@ def test_open_link_subprocess(terminal):
with mock.patch('time.sleep'), \ with mock.patch('time.sleep'), \
mock.patch('os.system'), \ mock.patch('os.system'), \
mock.patch('subprocess.Popen') as Popen, \
mock.patch('six.moves.input') as six_input, \ mock.patch('six.moves.input') as six_input, \
mock.patch.object(terminal, 'get_mailcap_entry'): mock.patch.object(terminal, 'get_mailcap_entry'):
@@ -398,6 +399,9 @@ def test_open_link_subprocess(terminal):
six_input.reset_mock() six_input.reset_mock()
os.system.reset_mock() os.system.reset_mock()
terminal.stdscr.subwin.addstr.reset_mock() terminal.stdscr.subwin.addstr.reset_mock()
Popen.return_value.communicate.return_value = '', 'stderr message'
Popen.return_value.poll.return_value = 0
Popen.return_value.wait.return_value = 0
def get_error(): def get_error():
# Check if an error message was printed to the terminal # Check if an error message was printed to the terminal
@@ -415,6 +419,8 @@ def test_open_link_subprocess(terminal):
# Non-blocking failure # Non-blocking failure
reset_mock() reset_mock()
Popen.return_value.poll.return_value = 127
Popen.return_value.wait.return_value = 127
entry = ('fake .', 'fake %s') entry = ('fake .', 'fake %s')
terminal.get_mailcap_entry.return_value = entry terminal.get_mailcap_entry.return_value = entry
terminal.open_link(url) terminal.open_link(url)
@@ -431,6 +437,8 @@ def test_open_link_subprocess(terminal):
# needsterminal failure # needsterminal failure
reset_mock() reset_mock()
Popen.return_value.poll.return_value = 127
Popen.return_value.wait.return_value = 127
entry = ('fake .', 'fake %s; needsterminal') entry = ('fake .', 'fake %s; needsterminal')
terminal.get_mailcap_entry.return_value = entry terminal.get_mailcap_entry.return_value = entry
terminal.open_link(url) terminal.open_link(url)
@@ -447,6 +455,8 @@ def test_open_link_subprocess(terminal):
# copiousoutput failure # copiousoutput failure
reset_mock() reset_mock()
Popen.return_value.poll.return_value = 127
Popen.return_value.wait.return_value = 127
entry = ('fake .', 'fake %s; needsterminal; copiousoutput') entry = ('fake .', 'fake %s; needsterminal; copiousoutput')
terminal.get_mailcap_entry.return_value = entry terminal.get_mailcap_entry.return_value = entry
terminal.open_link(url) terminal.open_link(url)