diff --git a/mwclient/errors.py b/mwclient/errors.py index 3dcc83a6044934789b96691da4f8006dd071e891..b7fdd1868c4c864240ad917c303f7a69c1e23867 100644 --- a/mwclient/errors.py +++ b/mwclient/errors.py @@ -35,7 +35,16 @@ class EditError(MwClientError): class ProtectedPageError(EditError, InsufficientPermission): - pass + + def __init__(self, page, code=None, info=None): + self.page = page + self.code = code + self.info = info + + def __str__(self): + if self.info is not None: + return self.info + return 'You do not have the "edit" right.' class FileExists(EditError): @@ -56,6 +65,20 @@ class OAuthAuthorizationError(LoginError): return self.info +class AssertUserFailedError(LoginError): + + def __init__(self): + self.message = 'By default, mwclient protects you from ' + \ + 'accidentally editing without being logged in. If you ' + \ + 'actually want to edit without logging in, you can set ' + \ + 'force_login on the Site object to False.' + + LoginError.__init__(self) + + def __str__(self): + return self.message + + class EmailError(MwClientError): pass diff --git a/mwclient/page.py b/mwclient/page.py index 54c763ce07ea10b5144af67e14d5c05244755306..b0abcf2bf3947c700a99d0619becdafbc75ad2c0 100644 --- a/mwclient/page.py +++ b/mwclient/page.py @@ -173,14 +173,7 @@ class Page(object): """Update the text of a section or the whole page by performing an edit operation. """ if not self.site.logged_in and self.site.force_login: - # Should we really check for this? - raise mwclient.errors.LoginError( - self.site, - 'By default, mwclient protects you from accidentally editing ' - 'without being logged in. ' - 'If you actually want to edit without logging in, ' - 'you can set force_login on the Site object to False.' - ) + raise mwclient.errors.AssertUserFailedError() if self.site.blocked: raise mwclient.errors.UserBlocked(self.site.blocked) if not self.can('edit'): @@ -205,6 +198,9 @@ class Page(object): data.update(kwargs) + if self.site.force_login: + data['assert'] = 'user' + def do_edit(): result = self.site.post('edit', title=self.name, text=text, summary=summary, token=self.get_token('edit'), @@ -238,10 +234,14 @@ class Page(object): raise mwclient.errors.EditError(self, summary, e.info) elif e.code in {'protectedtitle', 'cantcreate', 'cantcreate-anon', 'noimageredirect-anon', 'noimageredirect', 'noedit-anon', - 'noedit'}: + 'noedit', 'protectedpage', 'cascadeprotected', + 'customcssjsprotected', + 'protectednamespace-interface', 'protectednamespace'}: raise mwclient.errors.ProtectedPageError(self, e.code, e.info) + elif e.code == 'assertuserfailed': + raise mwclient.errors.AssertUserFailedError() else: - raise + raise e def move(self, new_title, reason='', move_talk=True, no_redirect=False): """Move (rename) page to new_title. diff --git a/tests/test_page.py b/tests/test_page.py index d7babe51b678b5e2ba3e8f2848721c4ddce34e28..4469b212f6f6af78066e9e3d94731237a2fbef77 100644 --- a/tests/test_page.py +++ b/tests/test_page.py @@ -10,6 +10,7 @@ import mock import mwclient from mwclient.page import Page from mwclient.client import Site +from mwclient.errors import APIError, AssertUserFailedError, ProtectedPageError try: import json @@ -215,8 +216,11 @@ class TestPageApiArgs(unittest.TestCase): 'ns': 0, 'pageid': 2, 'revisions': [{'*': 'Hello world', 'timestamp': '2014-08-29T22:25:15Z'}], 'title': title }}}} - def get_last_api_call_args(self): - args, kwargs = self.site.get.call_args + def get_last_api_call_args(self, http_method='POST'): + if http_method == 'GET': + args, kwargs = self.site.get.call_args + else: + args, kwargs = self.site.post.call_args action = args[0] args = args[1:] kwargs.update(args) @@ -228,7 +232,7 @@ class TestPageApiArgs(unittest.TestCase): def test_get_page_text(self): # Check that page.text() works, and that a correct API call is made text = self.page.text() - args = self.get_last_api_call_args() + args = self.get_last_api_call_args(http_method='GET') assert text == self.page_text assert args == { @@ -253,17 +257,68 @@ class TestPageApiArgs(unittest.TestCase): def test_get_section_text(self): # Check that the 'rvsection' parameter is sent to the API text = self.page.text(section=0) - args = self.get_last_api_call_args() + args = self.get_last_api_call_args(http_method='GET') assert args['rvsection'] == '0' def test_get_text_expanded(self): # Check that the 'rvexpandtemplates' parameter is sent to the API text = self.page.text(expandtemplates=True) - args = self.get_last_api_call_args() + args = self.get_last_api_call_args(http_method='GET') assert args['rvexpandtemplates'] == '1' + def test_assertuser_true(self): + # Check that assert=user is sent when force_login=True + self.site.blocked = False + self.site.rights = ['read', 'edit'] + self.site.logged_in = True + self.site.force_login = True + + self.site.api.return_value = { + 'edit': {'result': 'Ok'} + } + self.page.save('Some text') + args = self.get_last_api_call_args() + + assert args['assert'] == 'user' + + def test_assertuser_false(self): + # Check that assert=user is not sent when force_login=False + self.site.blocked = False + self.site.rights = ['read', 'edit'] + self.site.logged_in = False + self.site.force_login = False + + self.site.api.return_value = { + 'edit': {'result': 'Ok'} + } + self.page.save('Some text') + args = self.get_last_api_call_args() + + assert 'assert' not in args + + def test_handle_edit_error_assertuserfailed(self): + # Check that AssertUserFailedError is triggered + api_error = APIError('assertuserfailed', + 'Assertion that the user is logged in failed', + 'See https://en.wikipedia.org/w/api.php for API usage') + + with pytest.raises(AssertUserFailedError): + self.page.handle_edit_error(api_error, 'n/a') + + def test_handle_edit_error_protected(self): + # Check that ProtectedPageError is triggered + api_error = APIError('protectedpage', + 'The "editprotected" right is required to edit this page', + 'See https://en.wikipedia.org/w/api.php for API usage') + + with pytest.raises(ProtectedPageError) as pp_error: + self.page.handle_edit_error(api_error, 'n/a') + + assert pp_error.value.code == 'protectedpage' + assert str(pp_error.value) == 'The "editprotected" right is required to edit this page' + if __name__ == '__main__': unittest.main()