diff --git a/rtv/content.py b/rtv/content.py index adddece..1eee21e 100644 --- a/rtv/content.py +++ b/rtv/content.py @@ -182,31 +182,22 @@ class Content(object): return data @staticmethod - def strip_praw_subscription(subscription): + def strip_praw_reddit(reddit): """ - Parse through a subscription and return a dict with data ready to be + Parse through a reddit object and return a dict with data ready to be displayed through the terminal. """ data = {} - data['object'] = subscription - data['type'] = 'Subscription' - data['name'] = "/r/" + subscription.display_name - data['title'] = subscription.title - return data - - @staticmethod - def strip_praw_multireddit(multireddit): - """ - Parse through a multireddits and return a dict with data ready to be - displayed through the terminal. - """ - - data = {} - data['object'] = multireddit - data['type'] = 'Multireddit' - data['name'] = multireddit.path - data['title'] = multireddit.description_md + data['object'] = reddit + if isinstance(reddit, praw.objects.Subreddit): + data['type'] = 'Subreddit' + data['name'] = "/r/" + reddit.display_name + data['title'] = reddit.title + elif isinstance(reddit, praw.objects.Multireddit): + data['type'] = 'Multireddit' + data['name'] = reddit.path + data['title'] = reddit.description_md return data @staticmethod @@ -533,97 +524,48 @@ class SubredditContent(Content): return data -class SubscriptionContent(Content): +class ListRedditsContent(Content): - def __init__(self, subscriptions, loader): + def __init__(self, name, reddits, loader): - self.name = "Subscriptions" + self.name = name self.order = None self._loader = loader - self._subscriptions = subscriptions - self._subscription_data = [] + self._reddits = reddits + self._reddit_data = [] try: self.get(0) except IndexError: - raise exceptions.SubscriptionError('No subscriptions') + raise exceptions.ListRedditsError('No {}'.format(self.name)) @classmethod - def from_user(cls, reddit, loader): - subscriptions = reddit.get_my_subreddits(limit=None) - return cls(subscriptions, loader) + def from_user(cls, name, reddits, loader): + reddits = (r for r in reddits) + return cls(name, reddits, loader) def get(self, index, n_cols=70): """ - Grab the `i`th subscription, with the title field formatted to fit + Grab the `i`th reddit, with the title field formatted to fit inside of a window of width `n_cols` """ if index < 0: raise IndexError - while index >= len(self._subscription_data): + while index >= len(self._reddit_data): try: - with self._loader('Loading subscriptions'): - subscription = next(self._subscriptions) + with self._loader('Loading {}'.format(self.name)): + reddit = next(self._reddits) if self._loader.exception: raise IndexError except StopIteration: raise IndexError else: - data = self.strip_praw_subscription(subscription) - self._subscription_data.append(data) + data = self.strip_praw_reddit(reddit) + self._reddit_data.append(data) - data = self._subscription_data[index] - data['split_title'] = self.wrap_text(data['title'], width=n_cols) - data['n_rows'] = len(data['split_title']) + 1 - data['offset'] = 0 - - return data - - -class MultiredditContent(Content): - - def __init__(self, multireddits, loader): - - self.name = "Multireddits" - self.order = None - self._loader = loader - self._multireddits = multireddits - self._multireddit_data = [] - - try: - self.get(0) - except IndexError: - raise exceptions.SubscriptionError('No multireddits') - - @classmethod - def from_user(cls, reddit, multireddits, loader): - multireddits = (m for m in multireddits) - return cls(multireddits, loader) - - def get(self, index, n_cols=70): - """ - Grab the `i`th subscription, with the title field formatted to fit - inside of a window of width `n_cols` - """ - - if index < 0: - raise IndexError - - while index >= len(self._multireddit_data): - try: - with self._loader('Loading multireddits'): - multireddit = next(self._multireddits) - if self._loader.exception: - raise IndexError - except StopIteration: - raise IndexError - else: - data = self.strip_praw_multireddit(multireddit) - self._multireddit_data.append(data) - - data = self._multireddit_data[index] + data = self._reddit_data[index] data['split_title'] = self.wrap_text(data['title'], width=n_cols) data['n_rows'] = len(data['split_title']) + 1 data['offset'] = 0 diff --git a/rtv/exceptions.py b/rtv/exceptions.py index 1d9f976..31d9f5a 100644 --- a/rtv/exceptions.py +++ b/rtv/exceptions.py @@ -26,8 +26,8 @@ class SubredditError(RTVError): "Subreddit could not be reached" -class SubscriptionError(RTVError): - "Subscriptions could not be fetched" +class ListRedditsError(RTVError): + "List of reddits could not be fetched" class ProgramError(RTVError): @@ -39,4 +39,4 @@ class BrowserError(RTVError): class TemporaryFileError(RTVError): - "Indicates that an error has occurred and the file should not be deleted" \ No newline at end of file + "Indicates that an error has occurred and the file should not be deleted" diff --git a/rtv/reddits.py b/rtv/reddits.py new file mode 100644 index 0000000..a220833 --- /dev/null +++ b/rtv/reddits.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +import curses + +from .page import Page, PageController +from .content import ListRedditsContent +from .objects import Color, Navigator, Command + + +class ListRedditsController(PageController): + character_map = {} + + +class ListRedditsPage(Page): + + def __init__(self, reddit, name, reddits, term, config, oauth): + super(ListRedditsPage, self).__init__(reddit, term, config, oauth) + + self.controller = ListRedditsController(self, keymap=config.keymap) + self.name = name + self.content = ListRedditsContent.from_user(name, reddits, term.loader) + self.nav = Navigator(self.content.get) + self.reddit_data = None + + @ListRedditsController.register(Command('REFRESH')) + def refresh_content(self, order=None, name=None): + "Re-download all reddits and reset the page index" + + # reddit listings does not support sorting by order + if order: + self.term.flash() + return + + with self.term.loader(): + self.content = ListRedditsContent.from_user(self.name, self.reddit, + self.term.loader) + if not self.term.loader.exception: + self.nav = Navigator(self.content.get) + + @ListRedditsController.register(Command('SUBSCRIPTION_SELECT')) + def select_reddit(self): + "Store the selected reddit and return to the subreddit page" + + self.reddit_data = self.content.get(self.nav.absolute_index) + self.active = False + + @ListRedditsController.register(Command('SUBSCRIPTION_EXIT')) + def close_subscriptions(self): + "Close list of reddits and return to the subreddit page" + + self.active = False + + def _draw_banner(self): + # Subscriptions can't be sorted, so disable showing the order menu + pass + + def _draw_item(self, win, data, inverted): + n_rows, n_cols = win.getmaxyx() + n_cols -= 1 # Leave space for the cursor in the first column + + # Handle the case where the window is not large enough to fit the data. + valid_rows = range(0, n_rows) + offset = 0 if not inverted else -(data['n_rows'] - n_rows) + + row = offset + if row in valid_rows: + attr = curses.A_BOLD | Color.YELLOW + self.term.add_line(win, '{name}'.format(**data), row, 1, attr) + + row = offset + 1 + for row, text in enumerate(data['split_title'], start=row): + if row in valid_rows: + self.term.add_line(win, text, row, 1) diff --git a/rtv/subreddit.py b/rtv/subreddit.py index 51a0e0b..2141d81 100644 --- a/rtv/subreddit.py +++ b/rtv/subreddit.py @@ -9,8 +9,7 @@ from .content import SubredditContent from .page import Page, PageController, logged_in from .objects import Navigator, Color, Command from .submission import SubmissionPage -from .multireddits import MultiredditPage -from .subscription import SubscriptionPage +from .reddits import ListRedditsPage from .exceptions import TemporaryFileError @@ -158,8 +157,9 @@ class SubredditPage(Page): "Open user subscriptions page" with self.term.loader('Loading subscriptions'): - page = SubscriptionPage( - self.reddit, self.term, self.config, self.oauth) + page = ListRedditsPage(self.reddit, 'My Subscriptions', + self.reddit.get_my_subreddits(limit=None), self.term, + self.config, self.oauth) if self.term.loader.exception: return @@ -167,9 +167,9 @@ class SubredditPage(Page): # When the user has chosen a subreddit in the subscriptions list, # refresh content with the selected subreddit - if page.subreddit_data is not None: + if page.reddit_data is not None: self.refresh_content(order='ignore', - name=page.subreddit_data['name']) + name=page.reddit_data['name']) @SubredditController.register(Command('MULTIREDDIT_OPEN_SUBSCRIPTIONS')) @logged_in @@ -177,9 +177,9 @@ class SubredditPage(Page): "Open user multireddit subscriptions page" with self.term.loader('Loading multireddits'): - page = MultiredditPage( - self.reddit, self.reddit.get_my_multireddits(), - self.term, self.config) + page = ListRedditsPage(self.reddit, + 'My Multireddits', self.reddit.get_my_multireddits(), + self.term, self.config, self.oauth) if self.term.loader.exception: return @@ -187,10 +187,9 @@ class SubredditPage(Page): # When the user has chosen a subreddit in the subscriptions list, # refresh content with the selected subreddit - if page.multireddit_data is not None: + if page.reddit_data is not None: self.refresh_content(order='ignore', - name=page.multireddit_data['name']) - + name=page.reddit_data['name']) def _draw_item(self, win, data, inverted): diff --git a/tests/conftest.py b/tests/conftest.py index 32b22c5..d35d2ad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ from rtv.config import Config from rtv.terminal import Terminal from rtv.subreddit import SubredditPage from rtv.submission import SubmissionPage -from rtv.subscription import SubscriptionPage +from rtv.reddits import ListRedditsPage try: from unittest import mock @@ -220,13 +220,14 @@ def subreddit_page(reddit, terminal, config, oauth): @pytest.fixture() -def subscription_page(reddit, terminal, config, oauth, refresh_token): - # Must be logged in to view your subscriptions - config.refresh_token = refresh_token - oauth.authorize() +def reddits(reddit, terminal, config, oauth): + return reddit.get_popular_subreddits() + +@pytest.fixture() +def list_reddits_page(reddit, name, reddits, terminal, config, oauth): with terminal.loader(): - page = SubscriptionPage(reddit, terminal, config, oauth) + page = ListRedditsPage(reddit, name, reddits, terminal, config, oauth) assert terminal.loader.exception is None page.draw() return page diff --git a/tests/test_reddits.py b/tests/test_reddits.py new file mode 100644 index 0000000..5b2945e --- /dev/null +++ b/tests/test_reddits.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +import curses + +import praw +import pytest + +from rtv.reddits import ListRedditsPage + +try: + from unittest import mock +except ImportError: + import mock + + +def test_list_reddits_page_construct(reddit, reddits, terminal, config, + oauth, refresh_token): + window = terminal.stdscr.subwin + title = 'Subscriptions' + + with terminal.loader(): + page = ListRedditsPage(reddit, title, reddits, terminal, config, oauth) + assert terminal.loader.exception is None + + page.draw() + + # Header - Title + window.addstr.assert_any_call(0, 0, title.encode('utf-8')) + + # Header - Name + name = reddit.user.name.encode('utf-8') + window.addstr.assert_any_call(0, 59, name) + + # Banner shouldn't be drawn + menu = ('[1]hot ' + '[2]top ' + '[3]rising ' + '[4]new ' + '[5]controversial').encode('utf-8') + with pytest.raises(AssertionError): + window.addstr.assert_any_call(0, 0, menu) + + # Cursor - 2 lines + window.subwin.chgat.assert_any_call(0, 0, 1, 262144) + window.subwin.chgat.assert_any_call(1, 0, 1, 262144) + + # Reload with a smaller terminal window + terminal.stdscr.ncols = 20 + terminal.stdscr.nlines = 10 + with terminal.loader(): + page = ListRedditsPage(reddit, title, reddits, terminal, config, oauth) + assert terminal.loader.exception is None + + page.draw() + + +def test_list_reddits_refresh(list_reddits_page): + + # Refresh content - invalid order + list_reddits_page.controller.trigger('2') + assert curses.flash.called + curses.flash.reset_mock() + + # Refresh content + list_reddits_page.controller.trigger('r') + assert not curses.flash.called + + +def test_list_reddits_move(list_reddits_page): + + # Test movement + with mock.patch.object(list_reddits_page, 'clear_input_queue'): + + # Move cursor to the bottom of the page + while not curses.flash.called: + list_reddits_page.controller.trigger('j') + curses.flash.reset_mock() + assert list_reddits_page.nav.inverted + assert (list_reddits_page.nav.absolute_index == + len(list_reddits_page.content._reddit_data) - 1) + + # And back to the top + for i in range(list_reddits_page.nav.absolute_index): + list_reddits_page.controller.trigger('k') + assert not curses.flash.called + assert list_reddits_page.nav.absolute_index == 0 + assert not list_reddits_page.nav.inverted + + # Can't go up any further + list_reddits_page.controller.trigger('k') + assert curses.flash.called + assert list_reddits_page.nav.absolute_index == 0 + assert not list_reddits_page.nav.inverted + + # Page down should move the last item to the top + n = len(list_reddits_page._subwindows) + list_reddits_page.controller.trigger('n') + assert list_reddits_page.nav.absolute_index == n - 1 + + # And page up should move back up, but possibly not to the first item + list_reddits_page.controller.trigger('m') + + +def test_list_reddits_select(list_reddits_page): + + # Select a subreddit + list_reddits_page.controller.trigger(curses.KEY_ENTER) + assert list_reddits_page.reddit_data is not None + assert list_reddits_page.active is False + + +def test_list_reddits_close(list_reddits_page): + + # Close the list of reddits page + list_reddits_page.reddit_data = None + list_reddits_page.active = None + list_reddits_page.controller.trigger('h') + assert list_reddits_page.reddit_data is None + assert list_reddits_page.active is False + + +def test_list_reddits_page_invalid(list_reddits_page): + + # Test that other commands don't crash + methods = [ + 'a', # Upvote + 'z', # Downvote + 'd', # Delete + 'e', # Edit + ] + for ch in methods: + curses.flash.reset_mock() + list_reddits_page.controller.trigger(ch) + assert curses.flash.called