diff --git a/rtv/content.py b/rtv/content.py index b2a4aa9..914899b 100644 --- a/rtv/content.py +++ b/rtv/content.py @@ -68,6 +68,22 @@ def humanize_timestamp(utc_timestamp, verbose=False): years = months // 12 return ('%d years ago' % years) if verbose else ('%dyr' % years) +def validate_backslashes(name): + """ + Makes sure backslashes in subreddit name work with for_name() + """ + + if name[0] == '/': + raise SubredditNameError(name) + + if name[-1] == '/': + name = name[:-1] + + if name.count('/') > 1: + raise SubredditNameError(name) + + return name + @contextmanager def default_loader(self): yield @@ -309,9 +325,7 @@ class SubredditContent(BaseContent): self._submission_data = [] @classmethod - def from_name(cls, reddit, name, loader=default_loader): - - display_type = 'normal' + def from_name(cls, reddit, name, loader=default_loader, display_type = 'hot'): if name == 'front': return cls('Front Page', reddit.get_front_page(limit=None), loader) @@ -324,27 +338,53 @@ class SubredditContent(BaseContent): else: - if '/' in name: - if name[0] == '/': - raise SubredditNameError(name) + name = validate_backslashes(name) + if '/' in name: name, display_type = name.split('/') + if display_type not in ['new', 'top', 'hot', 'rising', 'controversial']: + raise SubredditNameError(name) + + if name == 'front': + + if display_type == 'new': + return cls('New', reddit.get_new(limit=None), loader) + + elif display_type == 'top': + return cls('Top', reddit.get_top(limit=None), loader) + + elif display_type == 'hot': + return cls('Front Page', reddit.get_front_page(limit=None), loader) + + elif display_type == 'rising': + return cls('Rising', reddit.get_rising(limit=None), loader) + + elif display_type == 'controversial': + return cls('Controversial', reddit.get_controversial(limit=None), loader) + try: with loader(): sub = reddit.get_subreddit(name, fetch=True) except praw.errors.ClientException: raise SubredditNameError(name) - if display_type == 'top': - return cls('/r/'+sub.display_name+'/top', sub.get_top_from_all(limit=None), loader) - - elif display_type == 'new': + if display_type == 'new': return cls('/r/'+sub.display_name+'/new', sub.get_new(limit=None), loader) - else: + elif display_type == 'top': + return cls('/r/'+sub.display_name+'/top', sub.get_top_from_all(limit=None), loader) + + elif display_type == 'hot': return cls('/r/'+sub.display_name, sub.get_hot(limit=None), loader) + elif display_type == 'rising': + return cls('/r/'+sub.display_name+'/rising', sub.get_rising(limit=None), loader) + + elif display_type == 'controversial': + return cls('/r/'+sub.display_name+'/controversial', sub.get_controversial(limit=None), loader) + + def get(self, index, n_cols=70): """ Grab the `i`th submission, with the title field formatted to fit inside