From 4a7074b4cfe96f68d71ba3078ba344f25ad17507 Mon Sep 17 00:00:00 2001
From: Bryan Tong Minh <bryan.tongminh@gmail.com>
Date: Sat, 23 Aug 2008 19:07:06 +0000
Subject: [PATCH] * Added max_items parameter to list * Raise an
 ApiDisabledError if the API is disabled * Properly fail on HTTPS sites

---
 mwclient/REFERENCE.txt |  3 ++-
 mwclient/client.py     | 10 +++++++++-
 mwclient/errors.py     |  8 ++++++--
 mwclient/http.py       | 28 ++++++++++++++++++++++++----
 mwclient/listing.py    |  9 ++++++++-
 5 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/mwclient/REFERENCE.txt b/mwclient/REFERENCE.txt
index 983df50..9e8f03b 100644
--- a/mwclient/REFERENCE.txt
+++ b/mwclient/REFERENCE.txt
@@ -79,7 +79,8 @@ their two letter prefix Exceptions:
 Properties and generators are implemented as Python generators. Their limit 
 parameter is only an indication of the number of items in one chunk. It is not
 the total limit. Doing list(generator(limit = limit)) will return ALL items of 
-generator, and not be limitted by the limit value.
+generator, and not be limitted by the limit value. Use list(generator(
+max_items = max_items)) to limit the amount of items returned.
 Default chunk size is generally the maximum chunk size.
 
 == Links ==
diff --git a/mwclient/client.py b/mwclient/client.py
index b6ccc6f..7f7cf3b 100644
--- a/mwclient/client.py
+++ b/mwclient/client.py
@@ -189,6 +189,8 @@ class Site(object):
 					raise
 				else:
 					self.wait(token)
+			except errors.HTTPRedirectError:
+				raise
 			except errors.HTTPError:
 				self.wait(token)
 			except ValueError:
@@ -198,7 +200,13 @@ class Site(object):
 		kwargs['action'] = action
 		kwargs['format'] = 'json'
 		data = self._query_string(*args, **kwargs)
-		return simplejson.load(self.raw_call('api', data))
+		json = self.raw_call('api', data).read()
+		try:
+			return simplejson.loads(json)
+		except ValueError:
+			if json.startswith('MediaWiki API is not enabled for this site.'):
+				raise errors.APIDisabledError
+			raise
 				
 	def raw_index(self, action, *args, **kwargs):
 		kwargs['action'] = action
diff --git a/mwclient/errors.py b/mwclient/errors.py
index 297e6b9..2a1dbd4 100644
--- a/mwclient/errors.py
+++ b/mwclient/errors.py
@@ -4,12 +4,16 @@ class MwClientError(RuntimeError):
 class MediaWikiVersionError(MwClientError):
 	pass
 
-	
+class APIDisabledError(MwClientError):
+	pass
+
 class HTTPError(MwClientError):
 	pass
 class HTTPStatusError(MwClientError):
 	pass
-	
+class HTTPRedirectError(HTTPError):
+	pass
+
 class MaximumRetriesExceeded(MwClientError):
 	pass
 	
diff --git a/mwclient/http.py b/mwclient/http.py
index d9ca695..6fb1619 100644
--- a/mwclient/http.py
+++ b/mwclient/http.py
@@ -50,11 +50,14 @@ class Cookie(object):
 		self.value = value
 		
 class HTTPPersistentConnection(object):
+	http_class = httplib.HTTPConnection
+	scheme_name = 'http'
+	
 	def __init__(self, host, pool = None):
 		self.cookies = {}
 		self.pool = pool
 		if pool: self.cookies = pool.cookies
-		self._conn = httplib.HTTPConnection(host)
+		self._conn = self.http_class(host)
 		self._conn.connect()
 		self.last_request = time.time()
 		
@@ -110,15 +113,25 @@ class HTTPPersistentConnection(object):
 					del headers['Content-Length']
 				method = 'GET'
 				data = ''
+			old_path = path
 			path = location[2]
 			if location[4]: path = path + '?' + location[4]
 			
+			print location[0]
+			if location[0].lower() != self.scheme_name:
+				raise errors.HTTPRedirectError, ('Only HTTP connections are supported',
+					res.getheader('Location'))
+			
 			if self.pool is None:
 				if location[1] != host: 
-					raise errors.HTTPError, ('Redirecting to different hosts not supported', 
+					raise errors.HTTPRedirectError, ('Redirecting to different hosts not supported', 
 						res.getheader('Location'))
+
 				return self.request(method, host, path, headers, data)
 			else:
+				if host == location[1] and path == old_path:
+					conn = self.__class__(location[1], self.pool)
+					self.pool.append(([location[1]], conn))
 				return self.pool.request(method, location[1], path, 
 					headers, data, stream_iter, raise_on_not_ok, auto_redirect)
 			
@@ -153,6 +166,10 @@ class HTTPConnection(HTTPPersistentConnection):
 			stream_iter, raise_on_not_ok, auto_redirect)
 		return res
 
+class HTTPSPersistentConnection(HTTPPersistentConnection):
+	http_class = httplib.HTTPSConnection
+	scheme_name = 'https'
+
 	
 class HTTPPool(list):
 	def __init__(self):
@@ -161,13 +178,15 @@ class HTTPPool(list):
 	def find_connection(self, host):
 		for hosts, conn in self:
 			if host in hosts: return conn
-				
+		
+		redirected_host = None
 		for hosts, conn in self:
 			status, headers = conn.head(host, '/')
 			if status == 200:
 				hosts.append(host)
 				return conn
 			if status >= 300 and status <= 399:
+				# BROKEN!
 				headers = dict(headers)
 				location = urlparse.urlparse(headers.get('location', ''))
 				if location[1] == host:
@@ -191,4 +210,5 @@ class HTTPPool(list):
 			headers, data, stream_iter, raise_on_not_ok, auto_redirect)
 	def close(self):
 		for hosts, conn in self:
-			conn.close()
\ No newline at end of file
+			conn.close()
+			
diff --git a/mwclient/listing.py b/mwclient/listing.py
index a9a50b4..253bbe9 100644
--- a/mwclient/listing.py
+++ b/mwclient/listing.py
@@ -2,7 +2,7 @@ import client, page
 import compatibility
 
 class List(object):
-	def __init__(self, site, list_name, prefix, limit = None, return_values = None, *args, **kwargs):
+	def __init__(self, site, list_name, prefix, limit = None, return_values = None, max_items = None, *args, **kwargs):
 		# NOTE: Fix limit
 		self.site = site
 		self.list_name = list_name
@@ -15,6 +15,9 @@ class List(object):
 		if limit is None: limit = site.api_limit
 		self.args[self.prefix + 'limit'] = str(limit)
 		
+		self.count = 0
+		self.max_items = max_items
+		
 		self._iter = iter(xrange(0))
 		
 		self.last = False
@@ -25,8 +28,12 @@ class List(object):
 		return self
 		
 	def next(self, full = False):
+		if self.max_items is not None:
+			if self.count >= self.max_items:
+				raise StopIteration
 		try:
 			item = self._iter.next()
+			self.count += 1
 			if 'timestamp' in item:
 				item['timestamp'] = client.parse_timestamp(item['timestamp'])
 			if full: return item
-- 
GitLab