diff --git a/rtv/content.py b/rtv/content.py index 1eee21e..ad6bf92 100644 --- a/rtv/content.py +++ b/rtv/content.py @@ -388,7 +388,7 @@ class SubredditContent(Content): # Strip leading, trailing, and redundant backslashes name_list = [seg for seg in name.strip(' /').split('/') if seg] name_order = None - if name_list[0] in ['r', 'u', 'user', 'domain'] and len(name_list) > 1: + if len(name_list) > 1 and name_list[0] in ['r', 'u', 'user', 'domain']: listing, name_list = name_list[0], name_list[1:] if len(name_list) == 2: name, name_order = name_list @@ -540,8 +540,8 @@ class ListRedditsContent(Content): raise exceptions.ListRedditsError('No {}'.format(self.name)) @classmethod - def from_user(cls, name, reddits, loader): - reddits = (r for r in reddits) + def from_func(cls, name, func, loader): + reddits = (r for r in func()) return cls(name, reddits, loader) def get(self, index, n_cols=70): diff --git a/rtv/reddits.py b/rtv/reddits.py index a220833..4786250 100644 --- a/rtv/reddits.py +++ b/rtv/reddits.py @@ -14,12 +14,13 @@ class ListRedditsController(PageController): class ListRedditsPage(Page): - def __init__(self, reddit, name, reddits, term, config, oauth): + def __init__(self, reddit, name, func, term, config, oauth): super(ListRedditsPage, self).__init__(reddit, term, config, oauth) self.controller = ListRedditsController(self, keymap=config.keymap) self.name = name - self.content = ListRedditsContent.from_user(name, reddits, term.loader) + self.func = func + self.content = ListRedditsContent.from_func(name, func, term.loader) self.nav = Navigator(self.content.get) self.reddit_data = None @@ -33,8 +34,8 @@ class ListRedditsPage(Page): return with self.term.loader(): - self.content = ListRedditsContent.from_user(self.name, self.reddit, - self.term.loader) + self.content = ListRedditsContent.from_func(self.name, + self.func, self.term.loader) if not self.term.loader.exception: self.nav = Navigator(self.content.get) diff --git a/rtv/subreddit.py b/rtv/subreddit.py index 2141d81..e9d1aa0 100644 --- a/rtv/subreddit.py +++ b/rtv/subreddit.py @@ -156,10 +156,10 @@ class SubredditPage(Page): def open_subscriptions(self): "Open user subscriptions page" + func = lambda : self.reddit.get_my_subreddits(limit=None) with self.term.loader('Loading subscriptions'): - page = ListRedditsPage(self.reddit, 'My Subscriptions', - self.reddit.get_my_subreddits(limit=None), self.term, - self.config, self.oauth) + page = ListRedditsPage(self.reddit, 'My Subscriptions', func, + self.term, self.config, self.oauth) if self.term.loader.exception: return @@ -176,10 +176,10 @@ class SubredditPage(Page): def open_multireddit_subscriptions(self): "Open user multireddit subscriptions page" + func = lambda : self.reddit.get_my_multireddits() with self.term.loader('Loading multireddits'): page = ListRedditsPage(self.reddit, - 'My Multireddits', self.reddit.get_my_multireddits(), - self.term, self.config, self.oauth) + 'My Multireddits', func, self.term, self.config, self.oauth) if self.term.loader.exception: return diff --git a/tests/conftest.py b/tests/conftest.py index 056eb5b..df470aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -210,14 +210,11 @@ def subreddit_page(reddit, terminal, config, oauth): @pytest.fixture() -def reddits(reddit, terminal, config, oauth): - return reddit.get_popular_subreddits() - - -@pytest.fixture() -def list_reddits_page(reddit, name, reddits, terminal, config, oauth): +def list_reddits_page(reddit, terminal, config, oauth): + title = 'Popular Subreddits' + func = reddit.get_popular_subreddits with terminal.loader(): - page = ListRedditsPage(reddit, name, reddits, terminal, config, oauth) + page = ListRedditsPage(reddit, title, func, terminal, config, oauth) assert terminal.loader.exception is None page.draw() return page diff --git a/tests/test_content.py b/tests/test_content.py index 7fc624c..c328511 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -9,7 +9,7 @@ import praw import pytest from rtv.content import ( - Content, SubmissionContent, SubredditContent, SubscriptionContent) + Content, SubmissionContent, SubredditContent, ListRedditsContent) from rtv import exceptions @@ -291,7 +291,6 @@ def test_content_subreddit_from_name(reddit, terminal): # Queries SubredditContent.from_name(reddit, 'front', terminal.loader, query='pea') SubredditContent.from_name(reddit, 'python', terminal.loader, query='pea') - SubredditContent.from_name(reddit, 'me', terminal.loader, query='pea') def test_content_subreddit_multireddit(reddit, terminal): @@ -333,23 +332,16 @@ def test_content_subreddit_me(reddit, oauth, refresh_token, terminal): assert isinstance(terminal.loader.exception, exceptions.SubredditError) -def test_content_subscription(reddit, oauth, refresh_token, terminal): +def test_content_list_reddits(reddit, oauth, refresh_token, terminal): - # Not logged in + title = 'Popular Subreddits' + func = reddit.get_popular_subreddits with terminal.loader(): - SubscriptionContent.from_user(reddit, terminal.loader) - assert isinstance( - terminal.loader.exception, praw.errors.LoginOrScopeRequired) - - # Logged in - oauth.config.refresh_token = refresh_token - oauth.authorize() - with terminal.loader(): - content = SubscriptionContent.from_user(reddit, terminal.loader) + content = ListRedditsContent.from_func(title, func, terminal.loader) assert terminal.loader.exception is None # These are static - assert content.name == 'Subscriptions' + assert content.name == title assert content.order is None # Validate content @@ -361,11 +353,11 @@ def test_content_subscription(reddit, oauth, refresh_token, terminal): assert not isinstance(val, six.binary_type) -def test_content_subscription_empty(terminal): +def test_content_list_reddits_empty(terminal): - # Simulate an empty subscription generator - subscriptions = iter([]) + # Simulate an empty list of reddits + func = lambda : iter([]) with terminal.loader(): - SubscriptionContent(subscriptions, terminal.loader) - assert isinstance(terminal.loader.exception, exceptions.SubscriptionError) + ListRedditsContent('test', func(), terminal.loader()) + assert isinstance(terminal.loader.exception, exceptions.ListRedditsError) diff --git a/tests/test_reddits.py b/tests/test_reddits.py index 5b2945e..f128c47 100644 --- a/tests/test_reddits.py +++ b/tests/test_reddits.py @@ -14,13 +14,14 @@ except ImportError: import mock -def test_list_reddits_page_construct(reddit, reddits, terminal, config, +def test_list_reddits_page_construct(reddit, terminal, config, oauth, refresh_token): window = terminal.stdscr.subwin - title = 'Subscriptions' + title = 'Popular Subreddits' + func = reddit.get_popular_subreddits with terminal.loader(): - page = ListRedditsPage(reddit, title, reddits, terminal, config, oauth) + page = ListRedditsPage(reddit, title, func, terminal, config, oauth) assert terminal.loader.exception is None page.draw() @@ -49,7 +50,7 @@ def test_list_reddits_page_construct(reddit, reddits, terminal, config, terminal.stdscr.ncols = 20 terminal.stdscr.nlines = 10 with terminal.loader(): - page = ListRedditsPage(reddit, title, reddits, terminal, config, oauth) + page = ListRedditsPage(reddit, title, func, terminal, config, oauth) assert terminal.loader.exception is None page.draw()