Generalize SubscriptionPage to handle lists of reddits

This commit is contained in:
woorst
2016-07-17 16:17:50 -05:00
parent 86cf7d0391
commit 92d16ad15f
6 changed files with 257 additions and 106 deletions

View File

@@ -182,31 +182,22 @@ class Content(object):
return data return data
@staticmethod @staticmethod
def strip_praw_subscription(subscription): def strip_praw_reddit(reddit):
""" """
Parse through a subscription and return a dict with data ready to be Parse through a reddit object and return a dict with data ready to be
displayed through the terminal. displayed through the terminal.
""" """
data = {} data = {}
data['object'] = subscription data['object'] = reddit
data['type'] = 'Subscription' if isinstance(reddit, praw.objects.Subreddit):
data['name'] = "/r/" + subscription.display_name data['type'] = 'Subreddit'
data['title'] = subscription.title data['name'] = "/r/" + reddit.display_name
return data data['title'] = reddit.title
elif isinstance(reddit, praw.objects.Multireddit):
@staticmethod
def strip_praw_multireddit(multireddit):
"""
Parse through a multireddits and return a dict with data ready to be
displayed through the terminal.
"""
data = {}
data['object'] = multireddit
data['type'] = 'Multireddit' data['type'] = 'Multireddit'
data['name'] = multireddit.path data['name'] = reddit.path
data['title'] = multireddit.description_md data['title'] = reddit.description_md
return data return data
@staticmethod @staticmethod
@@ -533,97 +524,48 @@ class SubredditContent(Content):
return data return data
class SubscriptionContent(Content): class ListRedditsContent(Content):
def __init__(self, subscriptions, loader): def __init__(self, name, reddits, loader):
self.name = "Subscriptions" self.name = name
self.order = None self.order = None
self._loader = loader self._loader = loader
self._subscriptions = subscriptions self._reddits = reddits
self._subscription_data = [] self._reddit_data = []
try: try:
self.get(0) self.get(0)
except IndexError: except IndexError:
raise exceptions.SubscriptionError('No subscriptions') raise exceptions.ListRedditsError('No {}'.format(self.name))
@classmethod @classmethod
def from_user(cls, reddit, loader): def from_user(cls, name, reddits, loader):
subscriptions = reddit.get_my_subreddits(limit=None) reddits = (r for r in reddits)
return cls(subscriptions, loader) return cls(name, reddits, loader)
def get(self, index, n_cols=70): def get(self, index, n_cols=70):
""" """
Grab the `i`th subscription, with the title field formatted to fit Grab the `i`th reddit, with the title field formatted to fit
inside of a window of width `n_cols` inside of a window of width `n_cols`
""" """
if index < 0: if index < 0:
raise IndexError raise IndexError
while index >= len(self._subscription_data): while index >= len(self._reddit_data):
try: try:
with self._loader('Loading subscriptions'): with self._loader('Loading {}'.format(self.name)):
subscription = next(self._subscriptions) reddit = next(self._reddits)
if self._loader.exception: if self._loader.exception:
raise IndexError raise IndexError
except StopIteration: except StopIteration:
raise IndexError raise IndexError
else: else:
data = self.strip_praw_subscription(subscription) data = self.strip_praw_reddit(reddit)
self._subscription_data.append(data) self._reddit_data.append(data)
data = self._subscription_data[index] data = self._reddit_data[index]
data['split_title'] = self.wrap_text(data['title'], width=n_cols)
data['n_rows'] = len(data['split_title']) + 1
data['offset'] = 0
return data
class MultiredditContent(Content):
def __init__(self, multireddits, loader):
self.name = "Multireddits"
self.order = None
self._loader = loader
self._multireddits = multireddits
self._multireddit_data = []
try:
self.get(0)
except IndexError:
raise exceptions.SubscriptionError('No multireddits')
@classmethod
def from_user(cls, reddit, multireddits, loader):
multireddits = (m for m in multireddits)
return cls(multireddits, loader)
def get(self, index, n_cols=70):
"""
Grab the `i`th subscription, with the title field formatted to fit
inside of a window of width `n_cols`
"""
if index < 0:
raise IndexError
while index >= len(self._multireddit_data):
try:
with self._loader('Loading multireddits'):
multireddit = next(self._multireddits)
if self._loader.exception:
raise IndexError
except StopIteration:
raise IndexError
else:
data = self.strip_praw_multireddit(multireddit)
self._multireddit_data.append(data)
data = self._multireddit_data[index]
data['split_title'] = self.wrap_text(data['title'], width=n_cols) data['split_title'] = self.wrap_text(data['title'], width=n_cols)
data['n_rows'] = len(data['split_title']) + 1 data['n_rows'] = len(data['split_title']) + 1
data['offset'] = 0 data['offset'] = 0

View File

@@ -26,8 +26,8 @@ class SubredditError(RTVError):
"Subreddit could not be reached" "Subreddit could not be reached"
class SubscriptionError(RTVError): class ListRedditsError(RTVError):
"Subscriptions could not be fetched" "List of reddits could not be fetched"
class ProgramError(RTVError): class ProgramError(RTVError):

74
rtv/reddits.py Normal file
View File

@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import curses
from .page import Page, PageController
from .content import ListRedditsContent
from .objects import Color, Navigator, Command
class ListRedditsController(PageController):
character_map = {}
class ListRedditsPage(Page):
def __init__(self, reddit, name, reddits, term, config, oauth):
super(ListRedditsPage, self).__init__(reddit, term, config, oauth)
self.controller = ListRedditsController(self, keymap=config.keymap)
self.name = name
self.content = ListRedditsContent.from_user(name, reddits, term.loader)
self.nav = Navigator(self.content.get)
self.reddit_data = None
@ListRedditsController.register(Command('REFRESH'))
def refresh_content(self, order=None, name=None):
"Re-download all reddits and reset the page index"
# reddit listings does not support sorting by order
if order:
self.term.flash()
return
with self.term.loader():
self.content = ListRedditsContent.from_user(self.name, self.reddit,
self.term.loader)
if not self.term.loader.exception:
self.nav = Navigator(self.content.get)
@ListRedditsController.register(Command('SUBSCRIPTION_SELECT'))
def select_reddit(self):
"Store the selected reddit and return to the subreddit page"
self.reddit_data = self.content.get(self.nav.absolute_index)
self.active = False
@ListRedditsController.register(Command('SUBSCRIPTION_EXIT'))
def close_subscriptions(self):
"Close list of reddits and return to the subreddit page"
self.active = False
def _draw_banner(self):
# Subscriptions can't be sorted, so disable showing the order menu
pass
def _draw_item(self, win, data, inverted):
n_rows, n_cols = win.getmaxyx()
n_cols -= 1 # Leave space for the cursor in the first column
# Handle the case where the window is not large enough to fit the data.
valid_rows = range(0, n_rows)
offset = 0 if not inverted else -(data['n_rows'] - n_rows)
row = offset
if row in valid_rows:
attr = curses.A_BOLD | Color.YELLOW
self.term.add_line(win, '{name}'.format(**data), row, 1, attr)
row = offset + 1
for row, text in enumerate(data['split_title'], start=row):
if row in valid_rows:
self.term.add_line(win, text, row, 1)

View File

@@ -9,8 +9,7 @@ from .content import SubredditContent
from .page import Page, PageController, logged_in from .page import Page, PageController, logged_in
from .objects import Navigator, Color, Command from .objects import Navigator, Color, Command
from .submission import SubmissionPage from .submission import SubmissionPage
from .multireddits import MultiredditPage from .reddits import ListRedditsPage
from .subscription import SubscriptionPage
from .exceptions import TemporaryFileError from .exceptions import TemporaryFileError
@@ -158,8 +157,9 @@ class SubredditPage(Page):
"Open user subscriptions page" "Open user subscriptions page"
with self.term.loader('Loading subscriptions'): with self.term.loader('Loading subscriptions'):
page = SubscriptionPage( page = ListRedditsPage(self.reddit, 'My Subscriptions',
self.reddit, self.term, self.config, self.oauth) self.reddit.get_my_subreddits(limit=None), self.term,
self.config, self.oauth)
if self.term.loader.exception: if self.term.loader.exception:
return return
@@ -167,9 +167,9 @@ class SubredditPage(Page):
# When the user has chosen a subreddit in the subscriptions list, # When the user has chosen a subreddit in the subscriptions list,
# refresh content with the selected subreddit # refresh content with the selected subreddit
if page.subreddit_data is not None: if page.reddit_data is not None:
self.refresh_content(order='ignore', self.refresh_content(order='ignore',
name=page.subreddit_data['name']) name=page.reddit_data['name'])
@SubredditController.register(Command('MULTIREDDIT_OPEN_SUBSCRIPTIONS')) @SubredditController.register(Command('MULTIREDDIT_OPEN_SUBSCRIPTIONS'))
@logged_in @logged_in
@@ -177,9 +177,9 @@ class SubredditPage(Page):
"Open user multireddit subscriptions page" "Open user multireddit subscriptions page"
with self.term.loader('Loading multireddits'): with self.term.loader('Loading multireddits'):
page = MultiredditPage( page = ListRedditsPage(self.reddit,
self.reddit, self.reddit.get_my_multireddits(), 'My Multireddits', self.reddit.get_my_multireddits(),
self.term, self.config) self.term, self.config, self.oauth)
if self.term.loader.exception: if self.term.loader.exception:
return return
@@ -187,10 +187,9 @@ class SubredditPage(Page):
# When the user has chosen a subreddit in the subscriptions list, # When the user has chosen a subreddit in the subscriptions list,
# refresh content with the selected subreddit # refresh content with the selected subreddit
if page.multireddit_data is not None: if page.reddit_data is not None:
self.refresh_content(order='ignore', self.refresh_content(order='ignore',
name=page.multireddit_data['name']) name=page.reddit_data['name'])
def _draw_item(self, win, data, inverted): def _draw_item(self, win, data, inverted):

View File

@@ -17,7 +17,7 @@ from rtv.config import Config
from rtv.terminal import Terminal from rtv.terminal import Terminal
from rtv.subreddit import SubredditPage from rtv.subreddit import SubredditPage
from rtv.submission import SubmissionPage from rtv.submission import SubmissionPage
from rtv.subscription import SubscriptionPage from rtv.reddits import ListRedditsPage
try: try:
from unittest import mock from unittest import mock
@@ -220,13 +220,14 @@ def subreddit_page(reddit, terminal, config, oauth):
@pytest.fixture() @pytest.fixture()
def subscription_page(reddit, terminal, config, oauth, refresh_token): def reddits(reddit, terminal, config, oauth):
# Must be logged in to view your subscriptions return reddit.get_popular_subreddits()
config.refresh_token = refresh_token
oauth.authorize()
@pytest.fixture()
def list_reddits_page(reddit, name, reddits, terminal, config, oauth):
with terminal.loader(): with terminal.loader():
page = SubscriptionPage(reddit, terminal, config, oauth) page = ListRedditsPage(reddit, name, reddits, terminal, config, oauth)
assert terminal.loader.exception is None assert terminal.loader.exception is None
page.draw() page.draw()
return page return page

135
tests/test_reddits.py Normal file
View File

@@ -0,0 +1,135 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import curses
import praw
import pytest
from rtv.reddits import ListRedditsPage
try:
from unittest import mock
except ImportError:
import mock
def test_list_reddits_page_construct(reddit, reddits, terminal, config,
oauth, refresh_token):
window = terminal.stdscr.subwin
title = 'Subscriptions'
with terminal.loader():
page = ListRedditsPage(reddit, title, reddits, terminal, config, oauth)
assert terminal.loader.exception is None
page.draw()
# Header - Title
window.addstr.assert_any_call(0, 0, title.encode('utf-8'))
# Header - Name
name = reddit.user.name.encode('utf-8')
window.addstr.assert_any_call(0, 59, name)
# Banner shouldn't be drawn
menu = ('[1]hot '
'[2]top '
'[3]rising '
'[4]new '
'[5]controversial').encode('utf-8')
with pytest.raises(AssertionError):
window.addstr.assert_any_call(0, 0, menu)
# Cursor - 2 lines
window.subwin.chgat.assert_any_call(0, 0, 1, 262144)
window.subwin.chgat.assert_any_call(1, 0, 1, 262144)
# Reload with a smaller terminal window
terminal.stdscr.ncols = 20
terminal.stdscr.nlines = 10
with terminal.loader():
page = ListRedditsPage(reddit, title, reddits, terminal, config, oauth)
assert terminal.loader.exception is None
page.draw()
def test_list_reddits_refresh(list_reddits_page):
# Refresh content - invalid order
list_reddits_page.controller.trigger('2')
assert curses.flash.called
curses.flash.reset_mock()
# Refresh content
list_reddits_page.controller.trigger('r')
assert not curses.flash.called
def test_list_reddits_move(list_reddits_page):
# Test movement
with mock.patch.object(list_reddits_page, 'clear_input_queue'):
# Move cursor to the bottom of the page
while not curses.flash.called:
list_reddits_page.controller.trigger('j')
curses.flash.reset_mock()
assert list_reddits_page.nav.inverted
assert (list_reddits_page.nav.absolute_index ==
len(list_reddits_page.content._reddit_data) - 1)
# And back to the top
for i in range(list_reddits_page.nav.absolute_index):
list_reddits_page.controller.trigger('k')
assert not curses.flash.called
assert list_reddits_page.nav.absolute_index == 0
assert not list_reddits_page.nav.inverted
# Can't go up any further
list_reddits_page.controller.trigger('k')
assert curses.flash.called
assert list_reddits_page.nav.absolute_index == 0
assert not list_reddits_page.nav.inverted
# Page down should move the last item to the top
n = len(list_reddits_page._subwindows)
list_reddits_page.controller.trigger('n')
assert list_reddits_page.nav.absolute_index == n - 1
# And page up should move back up, but possibly not to the first item
list_reddits_page.controller.trigger('m')
def test_list_reddits_select(list_reddits_page):
# Select a subreddit
list_reddits_page.controller.trigger(curses.KEY_ENTER)
assert list_reddits_page.reddit_data is not None
assert list_reddits_page.active is False
def test_list_reddits_close(list_reddits_page):
# Close the list of reddits page
list_reddits_page.reddit_data = None
list_reddits_page.active = None
list_reddits_page.controller.trigger('h')
assert list_reddits_page.reddit_data is None
assert list_reddits_page.active is False
def test_list_reddits_page_invalid(list_reddits_page):
# Test that other commands don't crash
methods = [
'a', # Upvote
'z', # Downvote
'd', # Delete
'e', # Edit
]
for ch in methods:
curses.flash.reset_mock()
list_reddits_page.controller.trigger(ch)
assert curses.flash.called