diff --git a/PKG-INFO b/PKG-INFO index e7a66b80723a0bf4745cad7baf86a9820d148fd0..c540d98259439fb7d977cb17f54f08c631df9c40 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,12 +1,12 @@ Metadata-Version: 1.1 Name: python3-openid -Version: 3.0.9 +Version: 3.1.0 Summary: OpenID support for modern servers and consumers. Home-page: http://github.com/necaris/python3-openid Author: Rami Chowdhury Author-email: rami.chowdhury@gmail.com License: UNKNOWN -Download-URL: http://github.com/necaris/python3-openid/tarball/v3.0.9 +Download-URL: http://github.com/necaris/python3-openid/tarball/v3.1.0 Description: This is a set of Python packages to support use of the OpenID decentralized identity system in your application, update to Python 3. Want to enable single sign-on for your web site? Use the openid.consumer diff --git a/admin/builddiscover.py b/admin/build_discover_data.py similarity index 94% rename from admin/builddiscover.py rename to admin/build_discover_data.py index 2837d76e05423c429e38f1bf8b3b2b70bf59a820..45391f169337b3a8b64e8cd870bee72b3aeb960a 100755 --- a/admin/builddiscover.py +++ b/admin/build_discover_data.py @@ -37,6 +37,7 @@ manifest_header = """\ """ + def buildDiscover(base_url, out_dir): """ Convert all files in a directory to apache mod_asis files in @@ -48,8 +49,8 @@ def buildDiscover(base_url, out_dir): """Helper to generate an output data file for a given test name.""" template = test_data[test_name] - data = discoverdata.fillTemplate( - test_name, template, base_url, discoverdata.example_xrds) + data = discoverdata.fillTemplate(test_name, template, base_url, + discoverdata.example_xrds) out_file_name = os.path.join(out_dir, test_name) out_file = open(out_file_name, 'w', encoding="utf-8") @@ -73,5 +74,6 @@ def buildDiscover(base_url, out_dir): for chunk in manifest: manifest_file.write(chunk) + if __name__ == '__main__': buildDiscover(*sys.argv[1:]) diff --git a/admin/fixperms b/admin/fixperms deleted file mode 100755 index d0303e11ff54f9e886c1a79a8fabd7cb171761cd..0000000000000000000000000000000000000000 --- a/admin/fixperms +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env bash -cat - <<EOF | xargs chmod +x -admin/builddiscover.py -admin/fixperms -admin/makechangelog -admin/pythonsource -admin/runtests -admin/setversion -admin/tagrelease -EOF \ No newline at end of file diff --git a/admin/gettlds.py b/admin/get_tlds.py similarity index 88% rename from admin/gettlds.py rename to admin/get_tlds.py index 579b13355fa13f05aaa174e4b80022e5d062f73f..ad875c09e3db1c559df7358bf32c8bdb977fd842 100644 --- a/admin/gettlds.py +++ b/admin/get_tlds.py @@ -19,22 +19,10 @@ LANGS = { "'", # line prefix "|", # separator "|' .", # line suffix - r")\.?$/'" # suffix - ), - 'python': ( - "['", - "'", - "', '", - "',", - "']" - ), - 'ruby': ( - "%w'", - "", - " ", - "", - "'" + r")\.?$/'" # suffix ), + 'python': ("['", "'", "', '", "',", "']"), + 'ruby': ("%w'", "", " ", "", "'"), } if __name__ == '__main__': diff --git a/admin/next_version.py b/admin/next_version.py new file mode 100644 index 0000000000000000000000000000000000000000..de83e85eb64ae973f205908475269fa91dd22984 --- /dev/null +++ b/admin/next_version.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +""" +Compute the next release version of the library, using `--major`, `--minor`, +or `--patch` arguments to determine the level at which the version is to be +incremented. +""" +import sys +from os.path import abspath, join, dirname + +if __name__ == '__main__': + sys.path.append(abspath(join(dirname(__file__), '..'))) + + import openid + + major, minor, patch = openid.version_info + pieces = None + + if '--major' in sys.argv: + pieces = (major + 1, 0, 0) + elif '--minor' in sys.argv: + pieces = (major, minor + 1, 0) + elif '--patch' in sys.argv: + pieces = (major, minor, patch + 1) + + if pieces: + print('.'.join(map(str, pieces)), end='') + else: + print('Major, minor, or patch?', file=sys.stderr) + sys.exit(1) diff --git a/admin/patch_version.py b/admin/patch_version.py new file mode 100644 index 0000000000000000000000000000000000000000..79f059f096af3d51f7b883396e0a9ad9dd44b083 --- /dev/null +++ b/admin/patch_version.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +""" +Update the `version_info` embedded in the library to the given version. +""" +import sys +from os.path import abspath, join, dirname + +if __name__ == '__main__': + try: + major, minor, patch = map(int, sys.argv[1].split('.')) + except (IndexError, ValueError): + print('Need version string in form MAJOR.MINOR.PATCH', file=sys.stderr) + sys.exit(1) + + TARGET = abspath(join(dirname(__file__), '..', 'openid', '__init__.py')) + + with open(TARGET, 'r', encoding='utf8') as f: + lines = f.readlines() + for i, l in enumerate(lines): + if l.startswith('version_info'): + v_info = '({}, {}, {})'.format(major, minor, patch) + lines[i] = 'version_info = {}\n\n'.format(v_info) + break + + with open(TARGET, 'w', encoding='utf8') as f: + f.writelines(lines) diff --git a/admin/runtests b/admin/runtests deleted file mode 100755 index 1343ea110f5928157b8d39027734ecf225791399..0000000000000000000000000000000000000000 --- a/admin/runtests +++ /dev/null @@ -1,209 +0,0 @@ -#!/usr/bin/env python -import os.path -import sys -import unittest - - -test_modules = [ - 'cryptutil', - 'oidutil', - 'dh', - ] - -def fixpath(): - try: - d = os.path.dirname(__file__) - except NameError: - d = os.path.dirname(sys.argv[0]) - parent = os.path.normpath(os.path.join(d, '..')) - if parent not in sys.path: - print ("putting %s in sys.path" % (parent,)) - sys.path.insert(0, parent) - -def otherTests(): - failed = [] - for module_name in test_modules: - print ('Testing %s...' % (module_name,)) - sys.stdout.flush() - module_name = 'openid.test.' + module_name - try: - test_mod = __import__(module_name, {}, {}, [None]) - except ImportError: - print ('Failed to import test %r' % (module_name,)) - failed.append(module_name) - else: - try: - test_mod.test() - except (SystemExit, KeyboardInterrupt): - raise - except: - sys.excepthook(*sys.exc_info()) - failed.append(module_name) - else: - print ('Succeeded.') - - - return failed - -def pyunitTests(): - import unittest - pyunit_module_names = [ - 'server', - 'consumer', - 'message', - 'symbol', - 'etxrd', - 'xri', - 'xrires', - 'association_response', - 'auth_request', - 'negotiation', - 'verifydisco', - 'sreg', - 'ax', - 'pape', - 'pape_draft2', - 'pape_draft5', - 'rpverify', - 'extension', - ] - - pyunit_modules = [ - __import__('openid.test.test_%s' % (name,), {}, {}, ['unused']) - for name in pyunit_module_names - ] - - try: - from openid.test import test_examples - except ImportError as e: - if 'twill' in str(e): - raise unittest.SkipTest('Skipping test_examples. ' - 'Could not import twill.') - else: - raise - else: - pyunit_modules.append(test_examples) - - # Some modules have data-driven tests, and they use custom methods - # to build the test suite: - custom_module_names = [ - 'kvform', - 'linkparse', - 'oidutil', - 'storetest', - 'test_accept', - 'test_association', - 'test_discover', - 'test_fetchers', - 'test_htmldiscover', - 'test_nonce', - 'test_openidyadis', - 'test_parsehtml', - 'test_urinorm', - 'test_yadis_discover', - 'trustroot', - ] - - loader = unittest.TestLoader() - s = unittest.TestSuite() - - for m in pyunit_modules: - s.addTest(loader.loadTestsFromModule(m)) - - for name in custom_module_names: - m = __import__('openid.test.%s' % (name,), {}, {}, ['unused']) - try: - s.addTest(m.pyUnitTests()) - except AttributeError as ex: - # because the AttributeError doesn't actually say which - # object it was. - print ("Error loading tests from %s:" % (name,)) - raise - - runner = unittest.TextTestRunner() # verbosity=2) - - return runner.run(s) - - - -def splitDir(d, count): - # in python2.4 and above, it's easier to spell this as - # d.rsplit(os.sep, count) - for i in range(count): - d = os.path.dirname(d) - return d - - - -def _import_djopenid(): - """Import djopenid from examples/ - - It's not in sys.path, and I don't really want to put it in sys.path. - """ - import types - thisfile = os.path.abspath(sys.modules[__name__].__file__) - topDir = splitDir(thisfile, 2) - djdir = os.path.join(topDir, 'examples', 'djopenid') - - djinit = os.path.join(djdir, '__init__.py') - - djopenid = types.ModuleType('djopenid') - with open(djinit) as f: - code = compile(f.read(), djinit, 'exec') - exec(code, djopenid.__dict__) - djopenid.__file__ = djinit - - # __path__ is the magic that makes child modules of the djopenid package - # importable. New feature in python 2.3, see PEP 302. - djopenid.__path__ = [djdir] - sys.modules['djopenid'] = djopenid - - - -def django_tests(): - """Runs tests from examples/djopenid. - - @returns: number of failed tests. - """ - import os - # Django uses this to find out where its settings are. - os.environ['DJANGO_SETTINGS_MODULE'] = 'djopenid.settings' - - _import_djopenid() - - try: - import django.test.simple - except ImportError as e: - raise unittest.SkipTest("django.test.simple not found; " - "django examples not tested.") - import djopenid.server.models, djopenid.consumer.models - print ("Testing Django examples:") - - # These tests do get put in to a pyunit test suite, so we could run them - # with the other pyunit tests, but django also establishes a test database - # for them, so we let it do that thing instead. - return django.test.simple.run_tests([djopenid.server.models, - djopenid.consumer.models]) - -try: - bool -except NameError: - def bool(x): - return not not x - -def main(): - fixpath() - other_failed = otherTests() - pyunit_result = pyunitTests() - django_failures = django_tests() - - if other_failed: - print ('Failures:', ', '.join(other_failed)) - - failed = (bool(other_failed) or - bool(not pyunit_result.wasSuccessful()) or - (django_failures > 0)) - return failed - -if __name__ == '__main__': - sys.exit(main() and 1 or 0) diff --git a/admin/setversion b/admin/setversion deleted file mode 100755 index ea2b20cb8a2ba535ab36a36f52a3679073f28933..0000000000000000000000000000000000000000 --- a/admin/setversion +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env bash - -cat <<EOF | \ - xargs sed -i 's/\[library version:[^]]*\]/[library version:'"$1"']/' -setup.py -openid/__init__.py -EOF diff --git a/contrib/associate b/contrib/associate index 4cb05c31c11f2857db3aac7a4c1df2a0c46f40ad..3d5142e9ea8db87f60f36493d575db85f7262687 100755 --- a/contrib/associate +++ b/contrib/associate @@ -10,6 +10,7 @@ from openid.consumer.discover import OpenIDServiceEndpoint from datetime import datetime + def verboseAssociation(assoc): """A more verbose representation of an Association. """ @@ -24,11 +25,12 @@ def verboseAssociation(assoc): """ return fmt % d + def main(): if not sys.argv[1:]: - print "Usage: %s ENDPOINT_URL..." % (sys.argv[0],) + print("Usage: %s ENDPOINT_URL..." % (sys.argv[0],)) for endpoint_url in sys.argv[1:]: - print "Associating with", endpoint_url + print("Associating with", endpoint_url) # This makes it clear why j3h made AssociationManager when we # did the ruby port. We can't invoke requestAssociation @@ -39,9 +41,10 @@ def main(): c = consumer.GenericConsumer(store) auth_req = c.begin(endpoint) if auth_req.assoc: - print verboseAssociation(auth_req.assoc) + print(verboseAssociation(auth_req.assoc)) else: - print " ...no association." + print(" ...no association.") + if __name__ == '__main__': main() diff --git a/contrib/openid-parse b/contrib/openid-parse index 21ab18dfdcce2bd729ac72e856b14f9ab87ba344..6f8a532028ff03d7a7d970bc4b9a816d6cfca265 100644 --- a/contrib/openid-parse +++ b/contrib/openid-parse @@ -8,19 +8,20 @@ Requires the 'xsel' program to get the contents of the clipboard. """ from pprint import pformat -from urlparse import urlsplit, urlunsplit +from urllib.parse import urlsplit, urlunsplit import cgi, re, subprocess, sys from openid import message OPENID_SORT_ORDER = ['mode', 'identity', 'claimed_id'] + class NoQuery(Exception): def __init__(self, url): self.url = url def __str__(self): - return "No query in url %s" % (self.url,) + return "No query in url %s" % (self.url, ) def getClipboard(): @@ -42,7 +43,7 @@ def main(): for url in urls: try: queries.append(queryFromURL(url)) - except NoQuery, err: + except NoQuery as err: errors.append(err) queries.extend(queriesFromLogs(source)) @@ -51,10 +52,10 @@ def main(): output.append('at %s:\n%s' % (where, openidFromQuery(query))) if output: - print '\n\n'.join(output) + print('\n\n'.join(output)) elif errors: for err in errors: - print err + print(err) def queryFromURL(url): @@ -73,7 +74,7 @@ def openidFromQuery(query): try: msg = message.Message.fromPostArgs(unlistify(query)) s = formatOpenIDMessage(msg) - except Exception, err: + except Exception as err: # XXX - side effect. sys.stderr.write(str(err)) s = pformat(query) @@ -124,6 +125,7 @@ def queriesFromLogs(s): return [(match.group(1), cgi.parse_qs(match.group(2))) for match in qre.finditer(s)] + def queriesFromPostdata(s): # This looks for query data in a line that starts POSTDATA=. # Tamperdata outputs such lines. If there's a 'Host=' in that block, @@ -133,16 +135,20 @@ def queriesFromPostdata(s): return [(match.group('host') or 'POSTDATA', cgi.parse_qs(match.group('query'))) for match in qre.finditer(s)] + def find_urls(s): # Regular expression borrowed from urlscan # by Daniel Burrows <dburrows@debian.org>, GPL. - urlinternalpattern=r'[{}a-zA-Z/\-_0-9%?&.=:;+,#~]' - urltrailingpattern=r'[{}a-zA-Z/\-_0-9%&=+#]' + urlinternalpattern = r'[{}a-zA-Z/\-_0-9%?&.=:;+,#~]' + urltrailingpattern = r'[{}a-zA-Z/\-_0-9%&=+#]' httpurlpattern = r'(?:https?://' + urlinternalpattern + r'*' + urltrailingpattern + r')' # Used to guess that blah.blah.blah.TLD is a URL. - tlds=['biz', 'com', 'edu', 'info', 'org'] - guessedurlpattern=r'(?:[a-zA-Z0-9_\-%]+(?:\.[a-zA-Z0-9_\-%]+)*\.(?:' + '|'.join(tlds) + '))' - urlre = re.compile(r'(?:<(?:URL:)?)?(' + httpurlpattern + '|' + guessedurlpattern + '|(?:mailto:[a-zA-Z0-9\-_]*@[0-9a-zA-Z_\-.]*[0-9a-zA-Z_\-]))>?') + tlds = ['biz', 'com', 'edu', 'info', 'org'] + guessedurlpattern = r'(?:[a-zA-Z0-9_\-%]+(?:\.[a-zA-Z0-9_\-%]+)*\.(?:' + '|'.join( + tlds) + '))' + urlre = re.compile( + r'(?:<(?:URL:)?)?(' + httpurlpattern + '|' + guessedurlpattern + + '|(?:mailto:[a-zA-Z0-9\-_]*@[0-9a-zA-Z_\-.]*[0-9a-zA-Z_\-]))>?') return [match.group(1) for match in urlre.finditer(s)] diff --git a/contrib/upgrade-store-1.1-to-2.0 b/contrib/upgrade-store-1.1-to-2.0 index 6a346826af565590301d058ebfc54e344e4097c7..5fafd4f14831fced94c8e92a2ddacae1b01e1166 100644 --- a/contrib/upgrade-store-1.1-to-2.0 +++ b/contrib/upgrade-store-1.1-to-2.0 @@ -23,15 +23,17 @@ from optparse import OptionParser def askForPassword(): return getpass.getpass("DB Password: ") -def askForConfirmation(dbname,tablename): + +def askForConfirmation(dbname, tablename): print """The table %s from the database %s will be dropped, and - an empty table with the new nonce table schema will replace it."""%( - tablename, dbname) + an empty table with the new nonce table schema will replace it.""" % ( + tablename, dbname) return raw_input("Continue? ").lower().strip().startswith('y') + def doSQLiteUpgrade(db_conn, nonce_table_name='oid_nonces'): cur = db_conn.cursor() - cur.execute('DROP TABLE %s'%nonce_table_name) + cur.execute('DROP TABLE %s' % nonce_table_name) sql = """ CREATE TABLE %s ( server_url VARCHAR, @@ -39,13 +41,14 @@ def doSQLiteUpgrade(db_conn, nonce_table_name='oid_nonces'): salt CHAR(40), UNIQUE(server_url, timestamp, salt) ); - """%nonce_table_name + """ % nonce_table_name cur.execute(sql) cur.close() + def doMySQLUpgrade(db_conn, nonce_table_name='oid_nonces'): cur = db_conn.cursor() - cur.execute('DROP TABLE %s'%nonce_table_name) + cur.execute('DROP TABLE %s' % nonce_table_name) sql = """ CREATE TABLE %s ( server_url BLOB, @@ -54,13 +57,14 @@ def doMySQLUpgrade(db_conn, nonce_table_name='oid_nonces'): PRIMARY KEY (server_url(255), timestamp, salt) ) TYPE=InnoDB; - """%nonce_table_name + """ % nonce_table_name cur.execute(sql) cur.close() + def doPostgreSQLUpgrade(db_conn, nonce_table_name='oid_nonces'): cur = db_conn.cursor() - cur.execute('DROP TABLE %s'%nonce_table_name) + cur.execute('DROP TABLE %s' % nonce_table_name) sql = """ CREATE TABLE %s ( server_url VARCHAR(2047), @@ -68,32 +72,49 @@ def doPostgreSQLUpgrade(db_conn, nonce_table_name='oid_nonces'): salt CHAR(40), PRIMARY KEY (server_url, timestamp, salt) ); - """%nonce_table_name + """ % nonce_table_name cur.execute(sql) cur.close() db_conn.commit() + def main(argv=None): parser = OptionParser() - parser.add_option("-u", "--user", dest="username", - default=os.environ.get('USER'), - help="User name to use to connect to the DB. " - "Defaults to USER environment variable.") - parser.add_option('-t', '--table', dest='tablename', default='oid_nonces', - help='The name of the nonce table to drop and recreate. ' - ' Defaults to "oid_nonces", the default table name for ' - 'the openid stores.') - parser.add_option('--mysql', dest='mysql_db_name', - help='Upgrade a table from this MySQL database. ' - 'Requires username for database.') - parser.add_option('--pg', '--postgresql', dest='postgres_db_name', - help='Upgrade a table from this PostgreSQL database. ' - 'Requires username for database.') - parser.add_option('--sqlite', dest='sqlite_db_name', - help='Upgrade a table from this SQLite database file.') - parser.add_option('--host', dest='db_host', - default='localhost', - help='Host on which to find MySQL or PostgreSQL DB.') + parser.add_option( + "-u", + "--user", + dest="username", + default=os.environ.get('USER'), + help="User name to use to connect to the DB. " + "Defaults to USER environment variable.") + parser.add_option( + '-t', + '--table', + dest='tablename', + default='oid_nonces', + help='The name of the nonce table to drop and recreate. ' + ' Defaults to "oid_nonces", the default table name for ' + 'the openid stores.') + parser.add_option( + '--mysql', + dest='mysql_db_name', + help='Upgrade a table from this MySQL database. ' + 'Requires username for database.') + parser.add_option( + '--pg', + '--postgresql', + dest='postgres_db_name', + help='Upgrade a table from this PostgreSQL database. ' + 'Requires username for database.') + parser.add_option( + '--sqlite', + dest='sqlite_db_name', + help='Upgrade a table from this SQLite database file.') + parser.add_option( + '--host', + dest='db_host', + default='localhost', + help='Host on which to find MySQL or PostgreSQL DB.') (options, args) = parser.parse_args(argv) db_conn = None @@ -129,10 +150,11 @@ def main(argv=None): return 1 try: - db_conn = psycopg2.connect(database = options.postgres_db_name, - user = options.username, - host = options.db_host, - password = password) + db_conn = psycopg2.connect( + database=options.postgres_db_name, + user=options.username, + host=options.db_host, + password=password) except Exception, e: print "Could not connect to PostgreSQL database:", str(e) return 1 diff --git a/contrib/upgrade-store-1.1-to-2.0~ b/contrib/upgrade-store-1.1-to-2.0~ deleted file mode 100644 index 2f6f0ebd6ad73ab6ab6ddcfc6304faa479c85f45..0000000000000000000000000000000000000000 --- a/contrib/upgrade-store-1.1-to-2.0~ +++ /dev/null @@ -1,170 +0,0 @@ -#!/usr/bin/env python -# SQL Store Upgrade Script -# for version 1.x to 2.0 of the OpenID library. -# Doesn't depend on the openid library, so you can run this python -# script to update databases for ruby or PHP as well. -# -# Testers note: -# -# A SQLite3 db with the 1.2 schema exists in -# openid/test/data/openid-1.2-consumer-sqlitestore.db if you want something -# to try upgrading. -# -# TODO: -# * test data for mysql and postgresql. -# * automated tests. - -import os -import getpass -import sys -from optparse import OptionParser - - -def askForPassword(): - return getpass.getpass("DB Password: ") - -def askForConfirmation(dbname,tablename): - print """The table %s from the database %s will be dropped, and - an empty table with the new nonce table schema will replace it."""%( - tablename, dbname) - return raw_input("Continue? ").lower().strip().startswith('y') - -def doSQLiteUpgrade(db_conn, nonce_table_name='oid_nonces'): - cur = db_conn.cursor() - cur.execute('DROP TABLE %s'%nonce_table_name) - sql = """ - CREATE TABLE %s ( - server_url VARCHAR, - timestamp INTEGER, - salt CHAR(40), - UNIQUE(server_url, timestamp, salt) - ); - """%nonce_table_name - cur.execute(sql) - cur.close() - -def doMySQLUpgrade(db_conn, nonce_table_name='oid_nonces'): - cur = db_conn.cursor() - cur.execute('DROP TABLE %s'%nonce_table_name) - sql = """ - CREATE TABLE %s ( - server_url BLOB, - timestamp INTEGER, - salt CHAR(40), - PRIMARY KEY (server_url(255), timestamp, salt) - ) - TYPE=InnoDB; - """%nonce_table_name - cur.execute(sql) - cur.close() - -def doPostgreSQLUpgrade(db_conn, nonce_table_name='oid_nonces'): - cur = db_conn.cursor() - cur.execute('DROP TABLE %s'%nonce_table_name) - sql = """ - CREATE TABLE %s ( - server_url VARCHAR(2047), - timestamp INTEGER, - salt CHAR(40), - PRIMARY KEY (server_url, timestamp, salt) - ); - """%nonce_table_name - cur.execute(sql) - cur.close() - db_conn.commit() - -def main(argv=None): - parser = OptionParser() - parser.add_option("-u", "--user", dest="username", - default=os.environ.get('USER'), - help="User name to use to connect to the DB. " - "Defaults to USER environment variable.") - parser.add_option('-t', '--table', dest='tablename', default='oid_nonces', - help='The name of the nonce table to drop and recreate. ' - ' Defaults to "oid_nonces", the default table name for ' - 'the openid stores.') - parser.add_option('--mysql', dest='mysql_db_name', - help='Upgrade a table from this MySQL database. ' - 'Requires username for database.') - parser.add_option('--pg', '--postgresql', dest='postgres_db_name', - help='Upgrade a table from this PostgreSQL database. ' - 'Requires username for database.') - parser.add_option('--sqlite', dest='sqlite_db_name', - help='Upgrade a table from this SQLite database file.') - parser.add_option('--host', dest='db_host', - default='localhost', - help='Host on which to find MySQL or PostgreSQL DB.') - (options, args) = parser.parse_args(argv) - - db_conn = None - - if options.sqlite_db_name: - try: - from pysqlite2 import dbapi2 as sqlite - except ImportError: - print "You must have pysqlite2 installed in your PYTHONPATH." - return 1 - try: - db_conn = sqlite.connect(options.sqlite_db_name) - except Exception, e: - print "Could not connect to SQLite database:", str(e) - return 1 - - if askForConfirmation(options.sqlite_db_name, options.tablename): - doSQLiteUpgrade(db_conn, nonce_table_name=options.tablename) - - if options.postgres_db_name: - if not options.username: - print "A username is required to open a PostgreSQL Database." - return 1 - password = askForPassword() - try: - import psycopg2 - except ImportError: - print "You need psycopg2 installed to update a postgres DB." - return 1 - - try: - db_conn = psycopg2.connect(database = options.postgres_db_name, - user = options.username, - host = options.db_host, - password = password) - except Exception, e: - print "Could not connect to PostgreSQL database:", str(e) - return 1 - - if askForConfirmation(options.postgres_db_name, options.tablename): - doPostgreSQLUpgrade(db_conn, nonce_table_name=options.tablename) - - if options.mysql_db_name: - if not options.username: - print "A username is required to open a MySQL Database." - return 1 - password = askForPassword() - try: - import MySQLdb - except ImportError: - print "You must have MySQLdb installed to update a MySQL DB." - return 1 - - try: - db_conn = MySQLdb.connect(options.db_host, options.username, - password, options.mysql_db_name) - except Exception, e: - print "Could not connect to MySQL database:", str(e) - return 1 - - if askForConfirmation(options.mysql_db_name, options.tablename): - doMySQLUpgrade(db_conn, nonce_table_name=options.tablename) - - if db_conn: - db_conn.close() - else: - parser.print_help() - - return 0 - - -if __name__ == '__main__': - retval = main() - sys.exit(retval) diff --git a/examples/consumer.py b/examples/consumer.py index 936733e12017f90cd4773122672a47366b70715c..358854f3ffd33ae89e902fd5973dfcc77c651332 100644 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -14,9 +14,11 @@ import urllib.parse import cgitb import sys + def quoteattr(s): qs = cgi.escape(s, 1) - return '"%s"' % (qs,) + return '"%s"' % (qs, ) + from http.server import HTTPServer, BaseHTTPRequestHandler @@ -43,13 +45,14 @@ from openid.extensions import pape, sreg # Used with an OpenID provider affiliate program. OPENID_PROVIDER_NAME = 'MyOpenID' -OPENID_PROVIDER_URL ='https://www.myopenid.com/affiliate_signup?affiliate_id=39' +OPENID_PROVIDER_URL = 'https://www.myopenid.com/affiliate_signup?affiliate_id=39' class OpenIDHTTPServer(HTTPServer): """http server that contains a reference to an OpenID consumer and knows its base URL. """ + def __init__(self, store, *args, **kwargs): HTTPServer.__init__(self, *args, **kwargs) self.sessions = {} @@ -59,7 +62,8 @@ class OpenIDHTTPServer(HTTPServer): self.base_url = ('http://%s:%s/' % (self.server_name, self.server_port)) else: - self.base_url = 'http://%s/' % (self.server_name,) + self.base_url = 'http://%s/' % (self.server_name, ) + class OpenIDRequestHandler(BaseHTTPRequestHandler): """Request handler that knows how to verify an OpenID identity.""" @@ -150,7 +154,8 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): self.send_header('Content-type', 'text/html') self.setSessionCookie() self.end_headers() - self.wfile.write(bytes(cgitb.html(sys.exc_info(), context=10), 'utf-8')) + self.wfile.write( + bytes(cgitb.html(sys.exc_info(), context=10), 'utf-8')) def doVerify(self): """Process the form submission, initating OpenID verification. @@ -159,27 +164,30 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): # First, make sure that the user entered something openid_url = self.query.get('openid_identifier') if not openid_url: - self.render('Enter an OpenID Identifier to verify.', - css_class='error', form_contents=openid_url) + self.render( + 'Enter an OpenID Identifier to verify.', + css_class='error', + form_contents=openid_url) return immediate = 'immediate' in self.query use_sreg = 'use_sreg' in self.query use_pape = 'use_pape' in self.query use_stateless = 'use_stateless' in self.query - oidconsumer = self.getConsumer(stateless = use_stateless) + oidconsumer = self.getConsumer(stateless=use_stateless) try: request = oidconsumer.begin(openid_url) except consumer.DiscoveryFailure as exc: fetch_error_string = 'Error in discovery: %s' % ( cgi.escape(str(exc))) - self.render(fetch_error_string, - css_class='error', - form_contents=openid_url) + self.render( + fetch_error_string, + css_class='error', + form_contents=openid_url) else: if request is None: msg = 'No OpenID services found for <code>%s</code>' % ( - cgi.escape(openid_url),) + cgi.escape(openid_url), ) self.render(msg, css_class='error', form_contents=openid_url) else: # Then, ask the library to begin the authorization. @@ -203,8 +211,9 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): self.end_headers() else: form_html = request.htmlMarkup( - trust_root, return_to, - form_tag_attrs={'id':'openid_message'}, + trust_root, + return_to, + form_tag_attrs={'id': 'openid_message'}, immediate=immediate) self.wfile.write(bytes(form_html, 'utf-8')) @@ -227,7 +236,7 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): # us. Status is a code indicating the response type. info is # either None or a string containing more information about # the return type. - url = 'http://'+self.headers.get('Host')+self.path + url = 'http://' + self.headers.get('Host') + self.path info = oidconsumer.complete(self.query, url) sreg_resp = None @@ -240,8 +249,7 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): # URL that we were verifying. We include it in the error # message to help the user figure out what happened. fmt = "Verification of %s failed: %s" - message = fmt % (cgi.escape(display_identifier), - info.message) + message = fmt % (cgi.escape(display_identifier), info.message) elif info.status == consumer.SUCCESS: # Success means that the transaction completed without # error. If info is None, it means that the user cancelled @@ -252,7 +260,7 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): # was a real application, we would do our login, # comment posting, etc. here. fmt = "You have successfully verified %s as your identity." - message = fmt % (cgi.escape(display_identifier),) + message = fmt % (cgi.escape(display_identifier), ) sreg_resp = sreg.SRegResponse.fromSuccessResponse(info) pape_resp = pape.Response.fromSuccessResponse(info) if info.endpoint.canonicalID: @@ -261,14 +269,14 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): # way their account with you is not compromised if their # i-name registration expires and is bought by someone else. message += (" This is an i-name, and its persistent ID is %s" - % (cgi.escape(info.endpoint.canonicalID),)) + % (cgi.escape(info.endpoint.canonicalID), )) elif info.status == consumer.CANCEL: # cancelled message = 'Verification cancelled' elif info.status == consumer.SETUP_NEEDED: if info.setup_url: message = '<a href=%s>Setup needed</a>' % ( - quoteattr(info.setup_url),) + quoteattr(info.setup_url), ) else: # This means auth didn't succeed, but you're welcome to try # non-immediate mode. @@ -280,8 +288,12 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): # information in a log. message = 'Verification failed.' - self.render(message, css_class, display_identifier, - sreg_data=sreg_resp, pape_data=pape_resp) + self.render( + message, + css_class, + display_identifier, + sreg_data=sreg_resp, + pape_data=pape_resp) def doAffiliate(self): """Direct the user sign up with an affiliate OpenID provider.""" @@ -294,20 +306,26 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): def renderSREG(self, sreg_data): if not sreg_data: - self.wfile.write(bytes('<div class="alert">No registration data was returned</div>', 'utf-8')) + self.wfile.write( + bytes( + '<div class="alert">No registration data was returned</div>', + 'utf-8')) else: sreg_list = list(sreg_data.items()) sreg_list.sort() - self.wfile.write(bytes('<h2>Registration Data</h2>' - '<table class="sreg">' - '<thead><tr><th>Field</th><th>Value</th></tr></thead>' - '<tbody>', 'utf-8')) + self.wfile.write( + bytes('<h2>Registration Data</h2>' + '<table class="sreg">' + '<thead><tr><th>Field</th><th>Value</th></tr></thead>' + '<tbody>', 'utf-8')) odd = ' class="odd"' for k, v in sreg_list: field_name = sreg.data_fields.get(k, k) value = cgi.escape(v.encode('UTF-8')) - self.wfile.write(bytes('<tr%s><td>%s</td><td>%s</td></tr>' % (odd, field_name, value), 'utf-8')) + self.wfile.write( + bytes('<tr%s><td>%s</td><td>%s</td></tr>' % ( + odd, field_name, value), 'utf-8')) if odd: odd = '' else: @@ -317,15 +335,22 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): def renderPAPE(self, pape_data): if not pape_data: - self.wfile.write(bytes('<div class="alert">No PAPE data was returned</div>', 'utf-8')) + self.wfile.write( + bytes('<div class="alert">No PAPE data was returned</div>', + 'utf-8')) else: - self.wfile.write(bytes('<div class="alert">Effective Auth Policies<ul>', 'utf-8')) + self.wfile.write( + bytes('<div class="alert">Effective Auth Policies<ul>', + 'utf-8')) for policy_uri in pape_data.auth_policies: - self.wfile.write(bytes('<li><tt>%s</tt></li>' % (cgi.escape(policy_uri),), 'utf-8')) + self.wfile.write( + bytes('<li><tt>%s</tt></li>' % (cgi.escape(policy_uri), ), + 'utf-8')) if not pape_data.auth_policies: - self.wfile.write(bytes('<li>No policies were applied.</li>', 'utf-8')) + self.wfile.write( + bytes('<li>No policies were applied.</li>', 'utf-8')) self.wfile.write(bytes('</ul></div>', 'utf-8')) @@ -338,20 +363,26 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): def notFound(self): """Render a page with a 404 return code and a message.""" fmt = 'The path <q>%s</q> was not understood by this server.' - msg = fmt % (self.path,) + msg = fmt % (self.path, ) openid_url = self.query.get('openid_identifier') self.render(msg, 'error', openid_url, status=404) - def render(self, message=None, css_class='alert', form_contents=None, - status=200, title="Python OpenID Consumer Example", - sreg_data=None, pape_data=None): + def render(self, + message=None, + css_class='alert', + form_contents=None, + status=200, + title="Python OpenID Consumer Example", + sreg_data=None, + pape_data=None): """Render a page.""" self.send_response(status) self.send_header("Content-type", "text/html") self.end_headers() self.pageHeader(title) if message: - self.wfile.write(("<div class='%s'>" % (css_class,)).encode('utf-8')) + self.wfile.write( + ("<div class='%s'>" % (css_class, )).encode('utf-8')) self.wfile.write(message.encode('utf-8')) self.wfile.write("</div>".encode('utf-8')) @@ -366,7 +397,8 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): def pageHeader(self, title): """Render the page header""" self.setSessionCookie() - self.wfile.write(bytes('''<html> + self.wfile.write( + bytes('''<html> <head><title>%s</title></head> <style type="text/css"> * { @@ -425,7 +457,8 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): """Render the page footer""" if not form_contents: form_contents = '' - self.wfile.write(bytes('''\ + self.wfile.write( + bytes('''\ <div id="verify-form"> <form method="get" accept-charset="UTF-8" action=%s> Identifier: @@ -441,6 +474,7 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): </html> ''' % (quoteattr(self.buildURL('verify')), quoteattr(form_contents)), 'UTF-8')) + def main(host, port, data_path, weak_ssl=False): # Instantiate OpenID consumer store and OpenID consumer. If you # were connecting to a database, you would create the database @@ -460,6 +494,7 @@ def main(host, port, data_path, weak_ssl=False): print(server.base_url) server.serve_forever() + if __name__ == '__main__': host = 'localhost' port = 8001 @@ -468,24 +503,37 @@ if __name__ == '__main__': try: import optparse except ImportError: - pass # Use defaults (for Python 2.2) + pass # Use defaults (for Python 2.2) else: parser = optparse.OptionParser('Usage:\n %prog [options]') parser.add_option( - '-d', '--data-path', dest='data_path', + '-d', + '--data-path', + dest='data_path', help='Data directory for storing OpenID consumer state. ' 'Setting this option implies using a "FileStore."') parser.add_option( - '-p', '--port', dest='port', type='int', default=port, + '-p', + '--port', + dest='port', + type='int', + default=port, help='Port on which to listen for HTTP requests. ' 'Defaults to port %default.') parser.add_option( - '-s', '--host', dest='host', default=host, + '-s', + '--host', + dest='host', + default=host, help='Host on which to listen for HTTP requests. ' 'Also used for generating URLs. Defaults to %default.') parser.add_option( - '-w', '--weakssl', dest='weakssl', default=False, - action='store_true', help='Skip ssl cert verification') + '-w', + '--weakssl', + dest='weakssl', + default=False, + action='store_true', + help='Skip ssl cert verification') options, args = parser.parse_args() if args: diff --git a/examples/discover b/examples/discover index 9b74e8a02f0cdd7c6d7db094daf9305828f77a79..f279a26935c8286ae1ca7590485a3b1a58634741 100644 --- a/examples/discover +++ b/examples/discover @@ -2,17 +2,19 @@ from openid.consumer.discover import discover, DiscoveryFailure from openid.fetchers import HTTPFetchingError -names = [["server_url", "Server URL "], - ["local_id", "Local ID "], - ["canonicalID", "Canonical ID"], - ] +names = [ + ["server_url", "Server URL "], + ["local_id", "Local ID "], + ["canonicalID", "Canonical ID"], +] + def show_services(user_input, normalized, services): print " Claimed identifier:", normalized if services: print " Discovered OpenID services:" for n, service in enumerate(services): - print " %s." % (n,) + print " %s." % (n, ) for attr, name in names: val = getattr(service, attr, None) if val is not None: @@ -28,6 +30,7 @@ def show_services(user_input, normalized, services): print " No OpenID services found" print + if __name__ == "__main__": import sys diff --git a/examples/djopenid/consumer/urls.py b/examples/djopenid/consumer/urls.py index d55e056cf89269433bafe485a2c242e7596d15f3..797c4b6d5b712e6b8d68001ddff1f23ae2250ebc 100644 --- a/examples/djopenid/consumer/urls.py +++ b/examples/djopenid/consumer/urls.py @@ -1,9 +1,7 @@ - from django.conf.urls.defaults import * urlpatterns = patterns( 'djopenid.consumer.views', (r'^$', 'startOpenID'), (r'^finish/$', 'finishOpenID'), - (r'^xrds/$', 'rpXRDS'), -) + (r'^xrds/$', 'rpXRDS'), ) diff --git a/examples/djopenid/consumer/views.py b/examples/djopenid/consumer/views.py index aaa4f963672409791da7f52ab580baecad665ebd..000bee4123016ca43ce87e02bed8f9f507750bce 100644 --- a/examples/djopenid/consumer/views.py +++ b/examples/djopenid/consumer/views.py @@ -1,4 +1,3 @@ - from django import http from django.http import HttpResponseRedirect from django.views.generic.base import TemplateView @@ -15,11 +14,10 @@ PAPE_POLICIES = [ 'AUTH_PHISHING_RESISTANT', 'AUTH_MULTI_FACTOR', 'AUTH_MULTI_FACTOR_PHYSICAL', - ] +] # List of (name, uri) for use in generating the request form. -POLICY_PAIRS = [(p, getattr(pape, p)) - for p in PAPE_POLICIES] +POLICY_PAIRS = [(p, getattr(pape, p)) for p in PAPE_POLICIES] def getOpenIDStore(): @@ -41,8 +39,7 @@ def renderIndexPage(request, **template_args): template_args['consumer_url'] = util.getViewURL(request, startOpenID) template_args['pape_policies'] = POLICY_PAIRS - response = TemplateView( - request, 'consumer/index.html', template_args) + response = TemplateView(request, 'consumer/index.html', template_args) response[YADIS_HEADER_NAME] = util.getViewURL(request, rpXRDS) return response @@ -73,7 +70,7 @@ def startOpenID(request): auth_request = c.begin(openid_url) except DiscoveryFailure as e: # Some other protocol-level failure occurred. - error = "OpenID discovery error: %s" % (str(e),) + error = "OpenID discovery error: %s" % (str(e), ) if error: # Render the page with an error. @@ -83,8 +80,8 @@ def startOpenID(request): # are optional, some are required. It's possible that the # server doesn't support sreg or won't return any of the # fields. - sreg_request = sreg.SRegRequest(optional=['email', 'nickname'], - required=['dob']) + sreg_request = sreg.SRegRequest( + optional=['email', 'nickname'], required=['dob']) auth_request.addExtension(sreg_request) # Add Attribute Exchange request information. @@ -92,11 +89,12 @@ def startOpenID(request): # XXX - uses myOpenID-compatible schema values, which are # not those listed at axschema.org. ax_request.add( - ax.AttrInfo('http://schema.openid.net/namePerson', - required=True)) + ax.AttrInfo('http://schema.openid.net/namePerson', required=True)) ax_request.add( - ax.AttrInfo('http://schema.openid.net/contact/web/default', - required=False, count=ax.UNLIMITED_VALUES)) + ax.AttrInfo( + 'http://schema.openid.net/contact/web/default', + required=False, + count=ax.UNLIMITED_VALUES)) auth_request.addExtension(ax_request) # Add PAPE request information. We'll ask for @@ -130,10 +128,10 @@ def startOpenID(request): # users will have to click the form submit button to # initiate OpenID authentication. form_id = 'openid_message' - form_html = auth_request.formMarkup(trust_root, return_to, - False, {'id': form_id}) - return TemplateView( - request, 'consumer/request_form.html', {'html': form_html}) + form_html = auth_request.formMarkup(trust_root, return_to, False, + {'id': form_id}) + return TemplateView(request, 'consumer/request_form.html', + {'html': form_html}) return renderIndexPage(request) @@ -174,11 +172,12 @@ def finishOpenID(request): ax_response = ax.FetchResponse.fromSuccessResponse(response) if ax_response: ax_items = { - 'fullname': ax_response.get( - 'http://schema.openid.net/namePerson'), - 'web': ax_response.get( + 'fullname': + ax_response.get('http://schema.openid.net/namePerson'), + 'web': + ax_response.get( 'http://schema.openid.net/contact/web/default'), - } + } # Get a PAPE response object if response information was # included in the OpenID response. @@ -191,18 +190,19 @@ def finishOpenID(request): # Map different consumer status codes to template contexts. results = { - consumer.CANCEL: - {'message': 'OpenID authentication cancelled.'}, - - consumer.FAILURE: - {'error': 'OpenID authentication failed.'}, - - consumer.SUCCESS: - {'url': response.getDisplayIdentifier(), - 'sreg': sreg_response and list(sreg_response.items()), - 'ax': list(ax_items.items()), - 'pape': pape_response} + consumer.CANCEL: { + 'message': 'OpenID authentication cancelled.' + }, + consumer.FAILURE: { + 'error': 'OpenID authentication failed.' + }, + consumer.SUCCESS: { + 'url': response.getDisplayIdentifier(), + 'sreg': sreg_response and list(sreg_response.items()), + 'ax': list(ax_items.items()), + 'pape': pape_response } + } result = results[response.status] @@ -220,7 +220,5 @@ def rpXRDS(request): """ Return a relying party verification XRDS document """ - return util.renderXRDS( - request, - [RP_RETURN_TO_URL_TYPE], - [util.getViewURL(request, finishOpenID)]) + return util.renderXRDS(request, [RP_RETURN_TO_URL_TYPE], + [util.getViewURL(request, finishOpenID)]) diff --git a/examples/djopenid/manage.py b/examples/djopenid/manage.py index 5e78ea979ea3846a4602f604e265fc4666beffac..eece40b215b9b7555b2c27fd37ad2b0a2d23cf85 100644 --- a/examples/djopenid/manage.py +++ b/examples/djopenid/manage.py @@ -1,10 +1,12 @@ #!/usr/bin/env python from django.core.management import execute_manager try: - import settings # Assumed to be in the same directory. + import settings # Assumed to be in the same directory. except ImportError: import sys - sys.stderr.write("Error: Can't find the file 'settings.py' in the directory containing %r. It appears you've customized things.\nYou'll have to run django-admin.py, passing it your settings module.\n(If the file settings.py does indeed exist, it's causing an ImportError somehow.)\n" % __file__) + sys.stderr.write( + "Error: Can't find the file 'settings.py' in the directory containing %r. It appears you've customized things.\nYou'll have to run django-admin.py, passing it your settings module.\n(If the file settings.py does indeed exist, it's causing an ImportError somehow.)\n" + % __file__) sys.exit(1) if __name__ == "__main__": diff --git a/examples/djopenid/server/tests.py b/examples/djopenid/server/tests.py index de5f1e924cef06c528854246f17d08f91ea468e2..b24edf69e04b76a078fc58913e0c6014f3b17d3b 100644 --- a/examples/djopenid/server/tests.py +++ b/examples/djopenid/server/tests.py @@ -1,4 +1,3 @@ - from django.test.testcases import TestCase from djopenid.server import views from djopenid import util @@ -29,11 +28,15 @@ class TestProcessTrustResult(TestCase): # Set up the OpenID request we're responding to. op_endpoint = 'http://127.0.0.1:8080/endpoint' message = Message.fromPostArgs({ - 'openid.mode': 'checkid_setup', - 'openid.identity': id_url, - 'openid.return_to': 'http://127.0.0.1/%s' % (self.id(),), - 'openid.sreg.required': 'postcode', - }) + 'openid.mode': + 'checkid_setup', + 'openid.identity': + id_url, + 'openid.return_to': + 'http://127.0.0.1/%s' % (self.id(), ), + 'openid.sreg.required': + 'postcode', + }) self.openid_request = CheckIDRequest.fromMessage(message, op_endpoint) views.setRequest(self.request, self.openid_request) @@ -70,11 +73,15 @@ class TestShowDecidePage(TestCase): # Set up the OpenID request we're responding to. op_endpoint = 'http://127.0.0.1:8080/endpoint' message = Message.fromPostArgs({ - 'openid.mode': 'checkid_setup', - 'openid.identity': id_url, - 'openid.return_to': 'http://unreachable.invalid/%s' % (self.id(),), - 'openid.sreg.required': 'postcode', - }) + 'openid.mode': + 'checkid_setup', + 'openid.identity': + id_url, + 'openid.return_to': + 'http://unreachable.invalid/%s' % (self.id(), ), + 'openid.sreg.required': + 'postcode', + }) self.openid_request = CheckIDRequest.fromMessage(message, op_endpoint) views.setRequest(self.request, self.openid_request) @@ -96,7 +103,7 @@ class TestGenericXRDS(TestCase): response = util.renderXRDS(request, type_uris, [endpoint_url]) requested_url = 'http://requested.invalid/' - (endpoint,) = applyFilter(requested_url, response.content) + (endpoint, ) = applyFilter(requested_url, response.content) self.assertEqual(YADIS_CONTENT_TYPE, response['Content-Type']) self.assertEqual(type_uris, endpoint.type_uris) diff --git a/examples/djopenid/server/urls.py b/examples/djopenid/server/urls.py index d6931a4dec94471c2a7e570d58eea366c83432eb..52a3661c94723f057db7f3243b596f7744e2f211 100644 --- a/examples/djopenid/server/urls.py +++ b/examples/djopenid/server/urls.py @@ -1,4 +1,3 @@ - from django.conf.urls.defaults import * urlpatterns = patterns( @@ -8,5 +7,4 @@ urlpatterns = patterns( (r'^processTrustResult/$', 'processTrustResult'), (r'^user/$', 'idPage'), (r'^endpoint/$', 'endpoint'), - (r'^trust/$', 'trustPage'), -) + (r'^trust/$', 'trustPage'), ) diff --git a/examples/djopenid/server/views.py b/examples/djopenid/server/views.py index 61c896bb8c77757af8e7b85abafc01f5761f027b..f0f15f1e4b48fdc190b7141ddf2394b8a7c173cc 100644 --- a/examples/djopenid/server/views.py +++ b/examples/djopenid/server/views.py @@ -1,4 +1,3 @@ - """ This module implements an example server for the OpenID library. Some functionality has been omitted intentionally; this code is intended to @@ -33,6 +32,7 @@ from openid.extensions import sreg from openid.extensions import pape from openid.fetchers import HTTPFetchingError + def getOpenIDStore(): """ Return an OpenID store object fit for the currently-chosen @@ -40,12 +40,14 @@ def getOpenIDStore(): """ return util.getOpenIDStore('/tmp/djopenid_s_store', 's_') + def getServer(request): """ Get a Server object to perform OpenID authentication. """ return Server(getOpenIDStore(), getViewURL(request, endpoint)) + def setRequest(request, openid_request): """ Store the openid request information in the session. @@ -55,40 +57,44 @@ def setRequest(request, openid_request): else: request.session['openid_request'] = None + def getRequest(request): """ Get an openid request from the session, if any. """ return request.session.get('openid_request') + def server(request): """ Respond to requests for the server's primary web page. """ return render_to_response( - 'server/index.html', - {'user_url': getViewURL(request, idPage), - 'server_xrds_url': getViewURL(request, idpXrds), - }, + 'server/index.html', { + 'user_url': getViewURL(request, idPage), + 'server_xrds_url': getViewURL(request, idpXrds), + }, context_instance=RequestContext(request)) + def idpXrds(request): """ Respond to requests for the IDP's XRDS document, which is used in IDP-driven identifier selection. """ - return util.renderXRDS( - request, [OPENID_IDP_2_0_TYPE], [getViewURL(request, endpoint)]) + return util.renderXRDS(request, [OPENID_IDP_2_0_TYPE], + [getViewURL(request, endpoint)]) + def idPage(request): """ Serve the identity page for OpenID URLs. """ return render_to_response( - 'server/idPage.html', - {'server_url': getViewURL(request, endpoint)}, + 'server/idPage.html', {'server_url': getViewURL(request, endpoint)}, context_instance=RequestContext(request)) + def trustPage(request): """ Display the trust page template, which allows the user to decide @@ -96,9 +102,10 @@ def trustPage(request): """ return render_to_response( 'server/trust.html', - {'trust_handler_url':getViewURL(request, processTrustResult)}, + {'trust_handler_url': getViewURL(request, processTrustResult)}, context_instance=RequestContext(request)) + def endpoint(request): """ Respond to low-level OpenID protocol messages. @@ -114,16 +121,14 @@ def endpoint(request): except ProtocolError as why: # This means the incoming request was invalid. return render_to_response( - 'server/endpoint.html', - {'error': str(why)}, - context_instance=RequestContext(request)) + 'server/endpoint.html', {'error': str(why)}, + context_instance=RequestContext(request)) # If we did not get a request, display text indicating that this # is an endpoint. if openid_request is None: return render_to_response( - 'server/endpoint.html', - {}, + 'server/endpoint.html', {}, context_instance=RequestContext(request)) # We got a request; if the mode is checkid_*, we will handle it by @@ -136,6 +141,7 @@ def endpoint(request): openid_response = s.handleRequest(openid_request) return displayResponse(request, openid_response) + def handleCheckIDRequest(request, openid_request): """ Handle checkid_* requests. Get input from the user to find out @@ -157,9 +163,8 @@ def handleCheckIDRequest(request, openid_request): if id_url != openid_request.identity: # Return an error response error_response = ProtocolError( - openid_request.message, - "This server cannot verify the URL %r" % - (openid_request.identity,)) + openid_request.message, "This server cannot verify the URL %r" + % (openid_request.identity, )) return displayResponse(request, error_response) @@ -177,6 +182,7 @@ def handleCheckIDRequest(request, openid_request): setRequest(request, openid_request) return showDecidePage(request, openid_request) + def showDecidePage(request, openid_request): """ Render a page to the user so a trust decision can be made. @@ -198,14 +204,15 @@ def showDecidePage(request, openid_request): pape_request = pape.Request.fromOpenIDRequest(openid_request) return render_to_response( - 'server/trust.html', - {'trust_root': trust_root, - 'trust_handler_url':getViewURL(request, processTrustResult), - 'trust_root_valid': trust_root_valid, - 'pape_request': pape_request, - }, + 'server/trust.html', { + 'trust_root': trust_root, + 'trust_handler_url': getViewURL(request, processTrustResult), + 'trust_root_valid': trust_root_valid, + 'pape_request': pape_request, + }, context_instance=RequestContext(request)) + def processTrustResult(request): """ Handle the result of a trust decision and respond to the RP @@ -223,8 +230,8 @@ def processTrustResult(request): allowed = 'allow' in request.POST # Generate a response with the appropriate answer. - openid_response = openid_request.answer(allowed, - identity=response_identity) + openid_response = openid_request.answer( + allowed, identity=response_identity) # Send Simple Registration data in the response, if appropriate. if allowed: @@ -238,7 +245,7 @@ def processTrustResult(request): 'country': 'ES', 'language': 'eu', 'timezone': 'America/New_York', - } + } sreg_req = sreg.SRegRequest.fromOpenIDRequest(openid_request) sreg_resp = sreg.SRegResponse.extractResponse(sreg_req, sreg_data) @@ -250,6 +257,7 @@ def processTrustResult(request): return displayResponse(request, openid_response) + def displayResponse(request, openid_response): """ Display an OpenID response. Errors will be displayed directly to @@ -266,9 +274,8 @@ def displayResponse(request, openid_response): # If it couldn't be encoded, display an error. text = why.response.encodeToKVForm() return render_to_response( - 'server/endpoint.html', - {'error': cgi.escape(text)}, - context_instance=RequestContext(request)) + 'server/endpoint.html', {'error': cgi.escape(text)}, + context_instance=RequestContext(request)) # Construct the appropriate django framework response. r = http.HttpResponse(webresponse.body) diff --git a/examples/djopenid/settings.py b/examples/djopenid/settings.py index b1ccc2699a05f5baaa9815c0c7162a4eac9c7212..b9bf06e9c08075a4196ac8be08713c2b8c9320f5 100644 --- a/examples/djopenid/settings.py +++ b/examples/djopenid/settings.py @@ -15,8 +15,7 @@ except ImportError as e: DEBUG = True TEMPLATE_DEBUG = DEBUG -ADMINS = ( - # ('Your Name', 'your_email@domain.com'), +ADMINS = ( # ('Your Name', 'your_email@domain.com'), ) MANAGERS = ADMINS @@ -59,30 +58,25 @@ ADMIN_MEDIA_PREFIX = '/media/' SECRET_KEY = 'u^bw6lmsa6fah0$^lz-ct$)y7x7#ag92-z+y45-8!(jk0lkavy' # List of callables that know how to import templates from various sources. -TEMPLATE_LOADERS = ( - 'django.template.loaders.filesystem.Loader', - 'django.template.loaders.app_directories.Loader', -) +TEMPLATE_LOADERS = ('django.template.loaders.filesystem.Loader', + 'django.template.loaders.app_directories.Loader', ) MIDDLEWARE_CLASSES = ( 'django.middleware.common.CommonMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.middleware.doc.XViewMiddleware', -) + 'django.middleware.doc.XViewMiddleware', ) ROOT_URLCONF = 'djopenid.urls' TEMPLATE_CONTEXT_PROCESSORS = () TEMPLATE_DIRS = ( - os.path.abspath(os.path.join(os.path.dirname(__file__), 'templates')), -) + os.path.abspath(os.path.join(os.path.dirname(__file__), 'templates')), ) INSTALLED_APPS = ( 'django.contrib.contenttypes', 'django.contrib.sessions', # These are the example OpenID consumer and server 'djopenid.consumer', - 'djopenid.server', -) + 'djopenid.server', ) diff --git a/examples/djopenid/urls.py b/examples/djopenid/urls.py index d91ee1f1d06d84186c39cb4a179047fd24bfd122..0b1f3c6ca5b8f877644ddca1f57a7bc8d8c4990a 100644 --- a/examples/djopenid/urls.py +++ b/examples/djopenid/urls.py @@ -4,5 +4,4 @@ urlpatterns = patterns( '', ('^$', 'djopenid.views.index'), ('^consumer/', include('djopenid.consumer.urls')), - ('^server/', include('djopenid.server.urls')), -) + ('^server/', include('djopenid.server.urls')), ) diff --git a/examples/djopenid/util.py b/examples/djopenid/util.py index fc21ceb5c78397e9f1f55bf9436e4bf2067f2d54..2e6a068f32438a2606d46d5ec7accce2f4135cf5 100644 --- a/examples/djopenid/util.py +++ b/examples/djopenid/util.py @@ -1,4 +1,3 @@ - """ Utility code for the Django example consumer and server. """ @@ -60,13 +59,13 @@ def getOpenIDStore(filestore_path, table_prefix): tablenames = { 'associations_table': table_prefix + 'openid_associations', 'nonces_table': table_prefix + 'openid_nonces', - } + } types = { 'django.db.backends.postgresql_psycopg2': sqlstore.PostgreSQLStore, 'django.db.backends.mysql': sqlstore.MySQLStore, 'django.db.backends.sqlite3': sqlstore.SQLiteStore, - } + } if db_engine not in types: raise ImproperlyConfigured( @@ -127,7 +126,7 @@ def getBaseURL(req): if port in [80, 443] or not port: port = '' else: - port = ':%s' % (port,) + port = ':%s' % (port, ) url = "%s://%s%s/" % (proto, name, port) return url @@ -150,9 +149,9 @@ def renderXRDS(request, type_uris, endpoint_urls): URLs in one service block, and return a response with the appropriate content-type. """ - response = render_to_response('xrds.xml', - {'type_uris': type_uris, - 'endpoint_urls': endpoint_urls}, - context_instance=RequestContext(request)) + response = render_to_response( + 'xrds.xml', {'type_uris': type_uris, + 'endpoint_urls': endpoint_urls}, + context_instance=RequestContext(request)) response['Content-Type'] = YADIS_CONTENT_TYPE return response diff --git a/examples/djopenid/views.py b/examples/djopenid/views.py index ddc9d259fb559395a4cd0a9c88e133d62e810f0b..d53015e83a3eec0ec8f462c393315cfe80c98c28 100644 --- a/examples/djopenid/views.py +++ b/examples/djopenid/views.py @@ -1,13 +1,13 @@ - from djopenid import util from django.views.generic.base import TemplateView def index(request): - consumer_url = util.getViewURL( - request, 'djopenid.consumer.views.startOpenID') + consumer_url = util.getViewURL(request, + 'djopenid.consumer.views.startOpenID') server_url = util.getViewURL(request, 'djopenid.server.views.server') - return TemplateView(request, 'index.html', - {'consumer_url': consumer_url, - 'server_url': server_url}) + return TemplateView(request, 'index.html', { + 'consumer_url': consumer_url, + 'server_url': server_url + }) diff --git a/examples/server.py b/examples/server.py index a50d53a158334e8c3a642e1de3c02ee2efcda0de..603b1a1dff5ca2656fd550162c3b11ab5418af52 100644 --- a/examples/server.py +++ b/examples/server.py @@ -11,9 +11,11 @@ import cgi import cgitb import sys + def quoteattr(s): qs = cgi.escape(s, 1) - return '"%s"' % (qs,) + return '"%s"' % (qs, ) + try: import openid @@ -33,11 +35,13 @@ from openid.server import server from openid.store.filestore import FileOpenIDStore from openid.consumer import discover + class OpenIDHTTPServer(HTTPServer): """ http server that contains a reference to an OpenID Server and knows its base URL. """ + def __init__(self, *args, **kwargs): HTTPServer.__init__(self, *args, **kwargs) @@ -45,7 +49,7 @@ class OpenIDHTTPServer(HTTPServer): self.base_url = ('http://%s:%s/' % (self.server_name, self.server_port)) else: - self.base_url = 'http://%s/' % (self.server_name,) + self.base_url = 'http://%s/' % (self.server_name, ) self.openid = None self.approved = {} @@ -60,7 +64,6 @@ class ServerHandler(BaseHTTPRequestHandler): self.user = None BaseHTTPRequestHandler.__init__(self, *args, **kwargs) - def do_GET(self): try: self.parsed_uri = urlparse(self.path) @@ -97,7 +100,8 @@ class ServerHandler(BaseHTTPRequestHandler): self.send_response(500) self.send_header('Content-type', 'text/html') self.end_headers() - self.wfile.write(bytes(cgitb.html(sys.exc_info(), context=10), 'utf-8')) + self.wfile.write( + bytes(cgitb.html(sys.exc_info(), context=10), 'utf-8')) def do_POST(self): try: @@ -127,7 +131,8 @@ class ServerHandler(BaseHTTPRequestHandler): self.send_response(500) self.send_header('Content-type', 'text/html') self.end_headers() - self.wfile.write(bytes(cgitb.html(sys.exc_info(), context=10), 'utf-8')) + self.wfile.write( + bytes(cgitb.html(sys.exc_info(), context=10), 'utf-8')) def handleAllow(self, query): # pretend this next bit is keying off the user's session or something, @@ -153,11 +158,10 @@ class ServerHandler(BaseHTTPRequestHandler): response = request.answer(False) else: - assert False, 'strange allow post. %r' % (query,) + assert False, 'strange allow post. %r' % (query, ) self.displayResponse(response) - def setUser(self): cookies = self.headers.get('Cookie') if cookies: @@ -200,9 +204,7 @@ class ServerHandler(BaseHTTPRequestHandler): # In a real application, this data would be user-specific, # and the user should be asked for permission to release # it. - sreg_data = { - 'nickname':self.user - } + sreg_data = {'nickname': self.user} sreg_resp = sreg.SRegResponse.extractResponse(sreg_req, sreg_data) response.addExtension(sreg_resp) @@ -251,7 +253,7 @@ class ServerHandler(BaseHTTPRequestHandler): elif 'cancel' in self.query: self.redirect(self.query['fail_to']) else: - assert 0, 'strange login %r' % (self.query,) + assert 0, 'strange login %r' % (self.query, ) def redirect(self, url): self.send_response(302) @@ -263,8 +265,7 @@ class ServerHandler(BaseHTTPRequestHandler): def writeUserHeader(self): if self.user is None: t1970 = time.gmtime(0) - expires = time.strftime( - 'Expires=%a, %d-%b-%y %H:%M:%S GMT', t1970) + expires = time.strftime('Expires=%a, %d-%b-%y %H:%M:%S GMT', t1970) self.send_header('Set-Cookie', 'user=;%s' % expires) else: self.send_header('Set-Cookie', 'user=%s' % self.user) @@ -285,20 +286,26 @@ class ServerHandler(BaseHTTPRequestHandler): ('http://www.openidenabled.com/', 'An OpenID community Web site, home of this library'), ('http://www.openid.net/', 'the official OpenID Web site'), - ] + ] resource_markup = ''.join([term(url, text) for url, text in resources]) - self.showPage(200, 'This is an OpenID server', msg="""\ + self.showPage( + 200, + 'This is an OpenID server', + msg="""\ <p>%s is an OpenID server endpoint.<p> <p>For more information about OpenID, see:</p> <dl> %s </dl> - """ % (link(endpoint_url), resource_markup,)) + """ % (link(endpoint_url), resource_markup, )) def showErrorPage(self, error_message): - self.showPage(400, 'Error Processing Request', err='''\ + self.showPage( + 400, + 'Error Processing Request', + err='''\ <p>%s</p> <!-- @@ -334,14 +341,14 @@ class ServerHandler(BaseHTTPRequestHandler): ''' % error_message) def showDecidePage(self, request): - id_url_base = self.server.base_url+'id/' + id_url_base = self.server.base_url + 'id/' # XXX: This may break if there are any synonyms for id_url_base, # such as referring to it by IP address or a CNAME. - assert (request.identity.startswith(id_url_base) or + assert (request.identity.startswith(id_url_base) or request.idSelect()), repr((request.identity, id_url_base)) expected_user = request.identity[len(id_url_base):] - if request.idSelect(): # We are being asked to select an ID + if request.idSelect(): # We are being asked to select an ID msg = '''\ <p>A site has asked for your identity. You may select an identifier by which you would like this site to know you. @@ -353,7 +360,7 @@ class ServerHandler(BaseHTTPRequestHandler): fdata = { 'id_url_base': id_url_base, 'trust_root': request.trust_root, - } + } form = '''\ <form method="POST" action="/allow"> <table> @@ -368,7 +375,7 @@ class ServerHandler(BaseHTTPRequestHandler): <input type="submit" name="yes" value="yes" /> <input type="submit" name="no" value="no" /> </form> - '''%fdata + ''' % fdata elif expected_user == self.user: msg = '''\ <p>A new site has asked to confirm your identity. If you @@ -380,7 +387,7 @@ class ServerHandler(BaseHTTPRequestHandler): fdata = { 'identity': request.identity, 'trust_root': request.trust_root, - } + } form = '''\ <table> <tr><td>Identity:</td><td>%(identity)s</td></tr> @@ -398,7 +405,7 @@ class ServerHandler(BaseHTTPRequestHandler): mdata = { 'expected_user': expected_user, 'user': self.user, - } + } msg = '''\ <p>A site has asked for an identity belonging to %(expected_user)s, but you are logged in as %(user)s. To @@ -410,7 +417,7 @@ class ServerHandler(BaseHTTPRequestHandler): 'identity': request.identity, 'trust_root': request.trust_root, 'expected_user': expected_user, - } + } form = '''\ <table> <tr><td>Identity:</td><td>%(identity)s</td></tr> @@ -450,7 +457,11 @@ class ServerHandler(BaseHTTPRequestHandler): else: msg = '' - self.showPage(200, 'An Identity Page', head_extras=disco_tags, msg='''\ + self.showPage( + 200, + 'An Identity Page', + head_extras=disco_tags, + msg='''\ <p>This is an identity page for %s.</p> %s ''' % (ident, msg)) @@ -462,7 +473,8 @@ class ServerHandler(BaseHTTPRequestHandler): endpoint_url = self.server.base_url + 'openidserver' user_url = self.server.base_url + 'id/' + user - self.wfile.write(bytes("""\ + self.wfile.write( + bytes("""\ <?xml version="1.0" encoding="UTF-8"?> <xrds:XRDS xmlns:xrds="xri://$xrds" @@ -478,8 +490,8 @@ class ServerHandler(BaseHTTPRequestHandler): </XRD> </xrds:XRDS> -"""%(discover.OPENID_2_0_TYPE, discover.OPENID_1_0_TYPE, - endpoint_url, user_url), 'utf-8')) +""" % (discover.OPENID_2_0_TYPE, discover.OPENID_1_0_TYPE, endpoint_url, + user_url), 'utf-8')) def showServerYadis(self): self.send_response(200) @@ -487,7 +499,8 @@ class ServerHandler(BaseHTTPRequestHandler): self.end_headers() endpoint_url = self.server.base_url + 'openidserver' - self.wfile.write(bytes("""\ + self.wfile.write( + bytes("""\ <?xml version="1.0" encoding="UTF-8"?> <xrds:XRDS xmlns:xrds="xri://$xrds" @@ -501,7 +514,7 @@ class ServerHandler(BaseHTTPRequestHandler): </XRD> </xrds:XRDS> -"""%(discover.OPENID_IDP_2_0_TYPE, endpoint_url,), 'utf-8')) +""" % (discover.OPENID_IDP_2_0_TYPE, endpoint_url, ), 'utf-8')) def showMainPage(self): yadis_tag = '<meta http-equiv="x-xrds-location" content="%s">'%\ @@ -519,7 +532,11 @@ class ServerHandler(BaseHTTPRequestHandler): order to simulate a standard Web user experience. You are not <a href='/login'>logged in</a>.</p>""" - self.showPage(200, 'Main Page', head_extras = yadis_tag, msg='''\ + self.showPage( + 200, + 'Main Page', + head_extras=yadis_tag, + msg='''\ <p>This is a simple OpenID server implemented using the <a href="http://openid.schtuff.com/">Python OpenID library</a>.</p> @@ -532,10 +549,14 @@ class ServerHandler(BaseHTTPRequestHandler): OpenID consumers outside of the firewall with it.</p> <p>The URL for this server is <a href=%s><tt>%s</tt></a>.</p> - ''' % (user_message, quoteattr(self.server.base_url), self.server.base_url)) + ''' % (user_message, quoteattr(self.server.base_url), + self.server.base_url)) def showLoginPage(self, success_to, fail_to): - self.showPage(200, 'Login Page', form='''\ + self.showPage( + 200, + 'Login Page', + form='''\ <h2>Login</h2> <p>You may log in with any name. This server does not use passwords because it is just a sample of how to use the OpenID @@ -549,8 +570,13 @@ class ServerHandler(BaseHTTPRequestHandler): </form> ''' % (success_to, fail_to)) - def showPage(self, response_code, title, - head_extras='', msg=None, err=None, form=None): + def showPage(self, + response_code, + title, + head_extras='', + msg=None, + err=None, + form=None): if self.user is None: user_link = '<a href="/login">not logged in</a>.' @@ -561,7 +587,7 @@ class ServerHandler(BaseHTTPRequestHandler): body = '' if err is not None: - body += '''\ + body += '''\ <div class="error"> %s </div> @@ -586,14 +612,15 @@ class ServerHandler(BaseHTTPRequestHandler): 'head_extras': head_extras, 'body': body, 'user_link': user_link, - } + } self.send_response(response_code) self.writeUserHeader() self.send_header('Content-type', 'text/html') self.end_headers() - self.wfile.write(bytes('''<html> + self.wfile.write( + bytes('''<html> <head> <title>%(title)s</title> %(head_extras)s @@ -669,15 +696,16 @@ class ServerHandler(BaseHTTPRequestHandler): </body> </html> ''' % contents, 'UTF-8')) - + def binaryToUTF8(self, data): args = {} - for key, value in data.items(): + for key, value in data.items(): key = key.decode('utf-8') value = value.decode('utf-8') args[key] = value return args + def main(host, port, data_path): addr = (host, port) httpserver = OpenIDHTTPServer(addr, ServerHandler) @@ -694,6 +722,7 @@ def main(host, port, data_path): print(httpserver.base_url) httpserver.serve_forever() + if __name__ == '__main__': host = 'localhost' data_path = 'sstore' @@ -702,19 +731,29 @@ if __name__ == '__main__': try: import optparse except ImportError: - pass # Use defaults (for Python 2.2) + pass # Use defaults (for Python 2.2) else: parser = optparse.OptionParser('Usage:\n %prog [options]') parser.add_option( - '-d', '--data-path', dest='data_path', default=data_path, + '-d', + '--data-path', + dest='data_path', + default=data_path, help='Data directory for storing OpenID consumer state. ' 'Defaults to "%default" in the current directory.') parser.add_option( - '-p', '--port', dest='port', type='int', default=port, + '-p', + '--port', + dest='port', + type='int', + default=port, help='Port on which to listen for HTTP requests. ' 'Defaults to port %default.') parser.add_option( - '-s', '--host', dest='host', default=host, + '-s', + '--host', + dest='host', + default=host, help='Host on which to listen for HTTP requests. ' 'Also used for generating URLs. Defaults to %default.') diff --git a/openid/__init__.py b/openid/__init__.py index 8577355396434f9a5c26a52b02e839e317be7745..465498f27a8c8ae93aa3ff296062b7d4b31e1c8b 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -1,4 +1,4 @@ -#-*-coding: utf-8-*- +#-*- coding: utf-8 -*- """ This package is an implementation of the OpenID specification in Python. It contains code for both server and consumer @@ -9,7 +9,7 @@ module. @contact: U{http://github.com/necaris/python3-openid/} -@copyright: (C) 2005-2008 JanRain, Inc., 2012-2013 Rami Chowdhury +@copyright: (C) 2005-2008 JanRain, Inc., 2012-2017 Rami Chowdhury @license: Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,7 +23,9 @@ module. and limitations under the License. """ -version_info = (3, 0, 9) +version_info = (3, 1, 0) + + __version__ = ".".join(str(x) for x in version_info) diff --git a/openid/association.py b/openid/association.py index fcba4bbe0048276e7612a1f93ba19c76081bd515..94d68d71ac0a18dcbe8fb12f9f99ff1127163eb0 100644 --- a/openid/association.py +++ b/openid/association.py @@ -1,5 +1,5 @@ #-*-test-case-name: openid.test.test_association-*- -#-*-coding: utf-8-*- +#-*- coding: utf-8 -*- """ This module contains code for dealing with associations between consumers and servers. Associations contain a shared secret that is @@ -24,6 +24,13 @@ association. does not support C{'no-encryption'} associations. It prefers HMAC-SHA1/DH-SHA1 association types if available. """ +import time +import functools + +from openid import cryptutil +from openid import kvform +from openid import oidutil +from openid.message import OPENID_NS __all__ = [ 'default_negotiator', @@ -32,13 +39,6 @@ __all__ = [ 'Association', ] -import time - -from openid import cryptutil -from openid import kvform -from openid import oidutil -from openid.message import OPENID_NS - all_association_types = [ 'HMAC-SHA1', 'HMAC-SHA256', @@ -76,7 +76,7 @@ def getSessionTypes(assoc_type): assoc_to_session = { 'HMAC-SHA1': ['DH-SHA1', 'no-encryption'], 'HMAC-SHA256': ['DH-SHA256', 'no-encryption'], - } + } return assoc_to_session.get(assoc_type, []) @@ -84,9 +84,8 @@ def checkSessionType(assoc_type, session_type): """Check to make sure that this pair of assoc type and session type are allowed""" if session_type not in getSessionTypes(assoc_type): - raise ValueError( - 'Session type %r not valid for assocation type %r' - % (session_type, assoc_type)) + raise ValueError('Session type %r not valid for assocation type %r' % + (session_type, assoc_type)) class SessionNegotiator(object): @@ -163,7 +162,7 @@ class SessionNegotiator(object): if not available: raise ValueError('No session available for association type %r' - % (assoc_type,)) + % (assoc_type, )) for session_type in getSessionTypes(assoc_type): self.addAllowedType(assoc_type, session_type) @@ -185,6 +184,7 @@ class SessionNegotiator(object): except IndexError: return (None, None) + default_negotiator = SessionNegotiator(default_association_order) encrypted_negotiator = SessionNegotiator(only_encrypted_association_order) @@ -195,9 +195,10 @@ def getSecretSize(assoc_type): elif assoc_type == 'HMAC-SHA256': return 32 else: - raise ValueError('Unsupported association type: %r' % (assoc_type,)) + raise ValueError('Unsupported association type: %r' % (assoc_type, )) +@functools.total_ordering class Association(object): """ This class represents an association between a server and a @@ -339,7 +340,7 @@ class Association(object): """ if assoc_type not in all_association_types: fmt = '%r is not a supported association type' - raise ValueError(fmt % (assoc_type,)) + raise ValueError(fmt % (assoc_type, )) # secret_size = getSecretSize(assoc_type) # if len(secret) != secret_size: @@ -373,6 +374,16 @@ class Association(object): return max(0, self.issued + self.lifetime - now) + def __lt__(self, other): + """ + Compare two C{L{Association}} instances to determine relative + ordering. + + Currently compares object lifetimes -- C{L{Association}} A < B + if A.lifetime < B.lifetime. + """ + return self.lifetime < other.lifetime + def __eq__(self, other): """ This checks to see if two C{L{Association}} instances @@ -472,8 +483,8 @@ class Association(object): try: mac = self._macs[self.assoc_type] except KeyError: - raise ValueError( - 'Unknown association type: %r' % (self.assoc_type,)) + raise ValueError('Unknown association type: %r' % + (self.assoc_type, )) return mac(self.secret, kv) @@ -500,7 +511,7 @@ class Association(object): @rtype: L{openid.message.Message} """ if (message.hasKey(OPENID_NS, 'sig') or - message.hasKey(OPENID_NS, 'signed')): + message.hasKey(OPENID_NS, 'signed')): raise ValueError('Message already has signed list or signature') extant_handle = message.getArg(OPENID_NS, 'assoc_handle') @@ -510,8 +521,7 @@ class Association(object): signed_message = message.copy() signed_message.setArg(OPENID_NS, 'assoc_handle', self.handle) message_keys = list(signed_message.toPostArgs().keys()) - signed_list = [k[7:] for k in message_keys - if k.startswith('openid.')] + signed_list = [k[7:] for k in message_keys if k.startswith('openid.')] signed_list.append('signed') signed_list.sort() signed_message.setArg(OPENID_NS, 'signed', ','.join(signed_list)) @@ -528,7 +538,7 @@ class Association(object): """ message_sig = message.getArg(OPENID_NS, 'sig') if not message_sig: - raise ValueError("%s has no sig." % (message,)) + raise ValueError("%s has no sig." % (message, )) calculated_sig = self.getMessageSignature(message) # remember, getMessageSignature returns bytes calculated_sig = calculated_sig.decode('utf-8') @@ -537,7 +547,7 @@ class Association(object): def _makePairs(self, message): signed = message.getArg(OPENID_NS, 'signed') if not signed: - raise ValueError('Message has no signed list: %s' % (message,)) + raise ValueError('Message has no signed list: %s' % (message, )) signed_list = signed.split(',') pairs = [] @@ -547,8 +557,6 @@ class Association(object): return pairs def __repr__(self): - return "<%s.%s %s %s>" % ( - self.__class__.__module__, - self.__class__.__name__, - self.assoc_type, - self.handle) + return "<%s.%s %s %s>" % (self.__class__.__module__, + self.__class__.__name__, self.assoc_type, + self.handle) diff --git a/openid/codecutil.py b/openid/codecutil.py index 90ee62e4e57e7d68fc799fe8997f7c040f7309f1..d8b9fe962a893a67c5643946f755e73b062b9cd9 100644 --- a/openid/codecutil.py +++ b/openid/codecutil.py @@ -8,11 +8,11 @@ except ValueError: (0xA0, 0xD7FF), (0xF900, 0xFDCF), (0xFDF0, 0xFFEF), - ] + ] IPRIVATE = [ (0xE000, 0xF8FF), - ] + ] else: UCSCHAR = [ (0xA0, 0xD7FF), @@ -32,13 +32,13 @@ else: (0xC0000, 0xCFFFD), (0xD0000, 0xDFFFD), (0xE1000, 0xEFFFD), - ] + ] IPRIVATE = [ (0xE000, 0xF8FF), (0xF0000, 0xFFFFD), (0x100000, 0x10FFFD), - ] + ] _ESCAPE_RANGES = UCSCHAR + IPRIVATE @@ -87,4 +87,5 @@ def _pct_escape_handler(err): replacements = _pct_encoded_replacements(chunk) return ("".join(replacements), err.end) + codecs.register_error("oid_percent_escape", _pct_escape_handler) diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 62196fcd9c5fb1d3fdb050943e40c593c5b06782..c081621c676a32fc1551da01e50889f250502b86 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -206,11 +206,18 @@ from openid.store.nonce import mkNonce, split as splitNonce from openid.yadis.manager import Discovery from openid import urinorm - -__all__ = ['AuthRequest', 'Consumer', 'SuccessResponse', - 'SetupNeededResponse', 'CancelResponse', 'FailureResponse', - 'SUCCESS', 'FAILURE', 'CANCEL', 'SETUP_NEEDED', - ] +__all__ = [ + 'AuthRequest', + 'Consumer', + 'SuccessResponse', + 'SetupNeededResponse', + 'CancelResponse', + 'FailureResponse', + 'SUCCESS', + 'FAILURE', + 'CANCEL', + 'SETUP_NEEDED', +] def makeKVPost(request_message, server_url): @@ -342,12 +349,12 @@ class Consumer(object): try: service = disco.getNextService(self._discover) except fetchers.HTTPFetchingError as why: - raise DiscoveryFailure( - 'Error fetching XRDS document: %s' % (why.why,), None) + raise DiscoveryFailure('Error fetching XRDS document: %s' % + (why.why, ), None) if service is None: - raise DiscoveryFailure( - 'No usable OpenID services found for %s' % (user_url,), None) + raise DiscoveryFailure('No usable OpenID services found for %s' % + (user_url, ), None) else: return self.beginWithoutDiscovery(service, anonymous) @@ -418,10 +425,9 @@ class Consumer(object): pass if (response.status in ['success', 'cancel'] and - response.identity_url is not None): + response.identity_url is not None): - disco = Discovery(self.session, - response.identity_url, + disco = Discovery(self.session, response.identity_url, self.session_key_prefix) # This is OK to do even if we did not do discovery in # the first place. @@ -473,13 +479,13 @@ class DiffieHellmanSHA1ConsumerSession(object): args.update({ 'dh_modulus': cryptutil.longToBase64(self.dh.modulus), 'dh_gen': cryptutil.longToBase64(self.dh.generator), - }) + }) return args def extractSecret(self, response): - dh_server_public64 = response.getArg( - OPENID_NS, 'dh_server_public', no_default) + dh_server_public64 = response.getArg(OPENID_NS, 'dh_server_public', + no_default) enc_mac_key64 = response.getArg(OPENID_NS, 'enc_mac_key', no_default) dh_server_public = cryptutil.base64ToLong(dh_server_public64) enc_mac_key = oidutil.fromBase64(enc_mac_key64) @@ -508,6 +514,7 @@ class PlainTextConsumerSession(object): class SetupNeededError(Exception): """Internally-used exception that indicates that an immediate-mode request cancelled.""" + def __init__(self, user_setup_url=None): Exception.__init__(self, user_setup_url) self.user_setup_url = user_setup_url @@ -529,8 +536,8 @@ class TypeURIMismatch(ProtocolError): def __str__(self): s = '<%s.%s: Required type %s not found in %s for endpoint %s>' % ( - self.__class__.__module__, self.__class__.__name__, - self.expected, self.endpoint.type_uris, self.endpoint) + self.__class__.__module__, self.__class__.__name__, self.expected, + self.endpoint.type_uris, self.endpoint) return s @@ -547,8 +554,8 @@ class ServerError(Exception): def fromMessage(cls, message): """Generate a ServerError instance, extracting the error text and the error code from the message.""" - error_text = message.getArg( - OPENID_NS, 'error', '<no error message supplied>') + error_text = message.getArg(OPENID_NS, 'error', + '<no error message supplied>') error_code = message.getArg(OPENID_NS, 'error_code') return cls(error_text, error_code, message) @@ -585,7 +592,7 @@ class GenericConsumer(object): 'DH-SHA1': DiffieHellmanSHA1ConsumerSession, 'DH-SHA256': DiffieHellmanSHA256ConsumerSession, 'no-encryption': PlainTextConsumerSession, - } + } _discover = staticmethod(discover) @@ -618,8 +625,7 @@ class GenericConsumer(object): """ mode = message.getArg(OPENID_NS, 'mode', '<No mode set>') - modeMethod = getattr(self, '_complete_' + mode, - self._completeInvalid) + modeMethod = getattr(self, '_complete_' + mode, self._completeInvalid) return modeMethod(message, endpoint, return_to) @@ -631,8 +637,8 @@ class GenericConsumer(object): contact = message.getArg(OPENID_NS, 'contact') reference = message.getArg(OPENID_NS, 'reference') - return FailureResponse(endpoint, error, contact=contact, - reference=reference) + return FailureResponse( + endpoint, error, contact=contact, reference=reference) def _complete_setup_needed(self, message, endpoint, _): if not message.isOpenID2(): @@ -654,8 +660,7 @@ class GenericConsumer(object): def _completeInvalid(self, message, endpoint, _): mode = message.getArg(OPENID_NS, 'mode', '<No mode set>') - return FailureResponse(endpoint, - 'Invalid openid.mode: %r' % (mode,)) + return FailureResponse(endpoint, 'Invalid openid.mode: %r' % (mode, )) def _checkReturnTo(self, message, return_to): """Check an OpenID message and its openid.return_to value @@ -667,7 +672,7 @@ class GenericConsumer(object): try: self._verifyReturnToArgs(message.toPostArgs()) except ProtocolError as why: - logging.exception("Verifying return_to arguments: %s" % (why,)) + logging.exception("Verifying return_to arguments: %s" % (why, )) return False # Check the return_to base URL against the one in the message. @@ -727,14 +732,14 @@ class GenericConsumer(object): if not self._checkReturnTo(message, return_to): raise ProtocolError( - "return_to does not match return URL. Expected %r, got %r" - % (return_to, message.getArg(OPENID_NS, 'return_to'))) + "return_to does not match return URL. Expected %r, got %r" % + (return_to, message.getArg(OPENID_NS, 'return_to'))) # Verify discovery information: endpoint = self._verifyDiscoveryResults(message, endpoint) logging.info("Received id_res response from %s using association %s" % - (endpoint.server_url, - message.getArg(OPENID_NS, 'assoc_handle'))) + (endpoint.server_url, + message.getArg(OPENID_NS, 'assoc_handle'))) self._idResCheckSignature(message, endpoint.server_url) @@ -773,10 +778,10 @@ class GenericConsumer(object): try: timestamp, salt = splitNonce(nonce) except ValueError as why: - raise ProtocolError('Malformed nonce: %s' % (why,)) + raise ProtocolError('Malformed nonce: %s' % (why, )) if (self.store is not None and - not self.store.useNonce(server_url, timestamp, salt)): + not self.store.useNonce(server_url, timestamp, salt)): raise ProtocolError('Nonce already used or out of range') def _idResCheckSignature(self, message, server_url): @@ -793,8 +798,8 @@ class GenericConsumer(object): # automatically opens the possibility for # denial-of-service by a server that just returns expired # associations (or really short-lived associations) - raise ProtocolError( - 'Association with %s expired' % (server_url,)) + raise ProtocolError('Association with %s expired' % + (server_url, )) if not assoc.checkMessageSignature(message): raise ProtocolError('Bad signature') @@ -820,19 +825,19 @@ class GenericConsumer(object): require_fields = { OPENID2_NS: basic_fields + ['op_endpoint'], OPENID1_NS: basic_fields + ['identity'], - } + } require_sigs = { - OPENID2_NS: basic_sig_fields + ['response_nonce', - 'claimed_id', - 'assoc_handle', - 'op_endpoint'], - OPENID1_NS: basic_sig_fields, - } + OPENID2_NS: + basic_sig_fields + + ['response_nonce', 'claimed_id', 'assoc_handle', 'op_endpoint'], + OPENID1_NS: + basic_sig_fields, + } for field in require_fields[message.getOpenIDNamespace()]: if not message.hasKey(OPENID_NS, field): - raise ProtocolError('Missing required field %r' % (field,)) + raise ProtocolError('Missing required field %r' % (field, )) signed_list_str = message.getArg(OPENID_NS, 'signed', no_default) signed_list = signed_list_str.split(',') @@ -840,7 +845,7 @@ class GenericConsumer(object): for field in require_sigs[message.getOpenIDNamespace()]: # Field is present and not in signed list if message.hasKey(OPENID_NS, field) and field not in signed_list: - raise ProtocolError('"%s" not signed' % (field,)) + raise ProtocolError('"%s" not signed' % (field, )) def _verifyReturnToArgs(query): """Verify that the arguments in the return_to URL are present in this @@ -877,8 +882,8 @@ class GenericConsumer(object): bare_args = message.getArgs(BARE_NS) for pair in bare_args.items(): if pair not in parsed_args: - raise ProtocolError( - "Parameter %s not in return_to URL" % (pair[0],)) + raise ProtocolError("Parameter %s not in return_to URL" % + (pair[0], )) _verifyReturnToArgs = staticmethod(_verifyReturnToArgs) @@ -904,18 +909,16 @@ class GenericConsumer(object): to_match.local_id = resp_msg.getArg(OPENID2_NS, 'identity') # Raises a KeyError when the op_endpoint is not present - to_match.server_url = resp_msg.getArg( - OPENID2_NS, 'op_endpoint', no_default) + to_match.server_url = resp_msg.getArg(OPENID2_NS, 'op_endpoint', + no_default) # claimed_id and identifier must both be present or both # be absent - if (to_match.claimed_id is None and - to_match.local_id is not None): + if (to_match.claimed_id is None and to_match.local_id is not None): raise ProtocolError( 'openid.identity is present without openid.claimed_id') - elif (to_match.claimed_id is not None and - to_match.local_id is None): + elif (to_match.claimed_id is not None and to_match.local_id is None): raise ProtocolError( 'openid.claimed_id is present without openid.identity') @@ -932,6 +935,11 @@ class GenericConsumer(object): if not endpoint: logging.info('No pre-discovered information supplied.') endpoint = self._discoverAndVerify(to_match.claimed_id, [to_match]) + elif endpoint.isOPIdentifier(): + logging.info( + 'Pre-discovered information based on OP-ID; need to rediscover.' + ) + endpoint = self._discoverAndVerify(to_match.claimed_id, [to_match]) else: # The claimed ID matches, so we use the endpoint that we # discovered in initiation. This should be the most common @@ -943,8 +951,8 @@ class GenericConsumer(object): "Error attempting to use stored discovery information: " + str(e)) logging.info("Attempting discovery to verify endpoint") - endpoint = self._discoverAndVerify( - to_match.claimed_id, [to_match]) + endpoint = self._discoverAndVerify(to_match.claimed_id, + [to_match]) # The endpoint we return should have the claimed ID from the # message we just verified, fragment and all. @@ -1058,15 +1066,16 @@ class GenericConsumer(object): @raises DiscoveryFailure: when discovery fails. """ - logging.info('Performing discovery on %s' % (claimed_id,)) + logging.info('Performing discovery on %s' % (claimed_id, )) _, services = self._discover(claimed_id) if not services: raise DiscoveryFailure('No OpenID information found at %s' % - (claimed_id,), None) + (claimed_id, ), None) return self._verifyDiscoveredServices(claimed_id, services, to_match_endpoints) - def _verifyDiscoveredServices(self, claimed_id, services, to_match_endpoints): + def _verifyDiscoveredServices(self, claimed_id, services, + to_match_endpoints): """See @L{_discoverAndVerify}""" # Search the services resulting from discovery to find one @@ -1075,8 +1084,7 @@ class GenericConsumer(object): for endpoint in services: for to_match_endpoint in to_match_endpoints: try: - self._verifyDiscoverySingle( - endpoint, to_match_endpoint) + self._verifyDiscoverySingle(endpoint, to_match_endpoint) except ProtocolError as why: failure_messages.append(str(why)) else: @@ -1085,13 +1093,13 @@ class GenericConsumer(object): return endpoint else: logging.error('Discovery verification failure for %s' % - (claimed_id,)) + (claimed_id, )) for failure_message in failure_messages: logging.error(' * Endpoint mismatch: ' + failure_message) raise DiscoveryFailure( - 'No matching endpoint found after discovering %s' - % (claimed_id,), None) + 'No matching endpoint found after discovering %s' % + (claimed_id, ), None) def _checkAuth(self, message, server_url): """Make a check_authentication request to verify this message. @@ -1126,7 +1134,7 @@ class GenericConsumer(object): # Signed value is missing if val is None: - logging.info('Missing signed field %r' % (k,)) + logging.info('Missing signed field %r' % (k, )) return None check_auth_message = message.copy() @@ -1141,11 +1149,11 @@ class GenericConsumer(object): invalidate_handle = response.getArg(OPENID_NS, 'invalidate_handle') if invalidate_handle is not None: - logging.info( - 'Received "invalidate_handle" from server %s' % (server_url,)) + logging.info('Received "invalidate_handle" from server %s' % + (server_url, )) if self.store is None: logging.error('Unexpectedly got invalidate_handle without ' - 'a store!') + 'a store!') else: self.store.removeAssociation(server_url, invalidate_handle) @@ -1188,28 +1196,26 @@ class GenericConsumer(object): assoc_type, session_type = self.negotiator.getAllowedType() try: - assoc = self._requestAssociation( - endpoint, assoc_type, session_type) + assoc = self._requestAssociation(endpoint, assoc_type, + session_type) except ServerError as why: - supportedTypes = self._extractSupportedAssociationType(why, - endpoint, - assoc_type) + supportedTypes = self._extractSupportedAssociationType( + why, endpoint, assoc_type) if supportedTypes is not None: assoc_type, session_type = supportedTypes # Attempt to create an association from the assoc_type # and session_type that the server told us it # supported. try: - assoc = self._requestAssociation( - endpoint, assoc_type, session_type) + assoc = self._requestAssociation(endpoint, assoc_type, + session_type) except ServerError as why: # Do not keep trying, since it rejected the # association type that it told us to use. logging.error( 'Server %s refused its suggested association ' - 'type: session_type=%s, assoc_type=%s' - % (endpoint.server_url, session_type, - assoc_type)) + 'type: session_type=%s, assoc_type=%s' % ( + endpoint.server_url, session_type, assoc_type)) return None else: return assoc @@ -1230,16 +1236,15 @@ class GenericConsumer(object): if server_error.error_code != 'unsupported-type' or \ server_error.message.isOpenID1(): logging.error( - 'Server error when requesting an association from %r: %s' - % (endpoint.server_url, server_error.error_text)) + 'Server error when requesting an association from %r: %s' % + (endpoint.server_url, server_error.error_text)) return None # The server didn't like the association/session type # that we sent, and it sent us back a message that # might tell us how to handle it. - logging.error( - 'Unsupported association type %s: %s' % (assoc_type, - server_error.error_text,)) + logging.error('Unsupported association type %s: %s' % + (assoc_type, server_error.error_text, )) # Extract the session_type and assoc_type from the # error message @@ -1248,7 +1253,7 @@ class GenericConsumer(object): if assoc_type is None or session_type is None: logging.error('Server responded with unsupported association ' - 'session but did not supply a fallback.') + 'session but did not supply a fallback.') return None elif not self.negotiator.isAllowed(assoc_type, session_type): fmt = ('Server sent unsupported session/association type: ' @@ -1273,19 +1278,19 @@ class GenericConsumer(object): try: response = self._makeKVPost(args, endpoint.server_url) except fetchers.HTTPFetchingError as why: - logging.exception('openid.associate request failed: %s' % (why,)) + logging.exception('openid.associate request failed: %s' % (why, )) return None try: assoc = self._extractAssociation(response, assoc_session) except KeyError as why: logging.exception( - 'Missing required parameter in response from %s: %s' - % (endpoint.server_url, why)) + 'Missing required parameter in response from %s: %s' % + (endpoint.server_url, why)) return None except ProtocolError as why: - logging.exception('Protocol error parsing response from %s: %s' % ( - endpoint.server_url, why)) + logging.exception('Protocol error parsing response from %s: %s' % + (endpoint.server_url, why)) return None else: return assoc @@ -1320,7 +1325,7 @@ class GenericConsumer(object): args = { 'mode': 'associate', 'assoc_type': assoc_type, - } + } if not endpoint.compatibilityMode(): args['ns'] = OPENID2_NS @@ -1328,7 +1333,7 @@ class GenericConsumer(object): # Leave out the session type if we're in compatibility mode # *and* it's no-encryption. if (not endpoint.compatibilityMode() or - assoc_session.session_type != 'no-encryption'): + assoc_session.session_type != 'no-encryption'): args['session_type'] = assoc_session.session_type args.update(assoc_session.getRequest()) @@ -1362,7 +1367,7 @@ class GenericConsumer(object): # warning. if session_type == 'no-encryption': logging.warning('OpenID server sent "no-encryption"' - 'for OpenID 1.X') + 'for OpenID 1.X') # Missing or empty session type is the way to flag a # 'no-encryption' response. Change the session type to @@ -1393,33 +1398,32 @@ class GenericConsumer(object): """ # Extract the common fields from the response, raising an # exception if they are not found - assoc_type = assoc_response.getArg( - OPENID_NS, 'assoc_type', no_default) - assoc_handle = assoc_response.getArg( - OPENID_NS, 'assoc_handle', no_default) + assoc_type = assoc_response.getArg(OPENID_NS, 'assoc_type', no_default) + assoc_handle = assoc_response.getArg(OPENID_NS, 'assoc_handle', + no_default) # expires_in is a base-10 string. The Python parsing will # accept literals that have whitespace around them and will # accept negative values. Neither of these are really in-spec, # but we think it's OK to accept them. - expires_in_str = assoc_response.getArg( - OPENID_NS, 'expires_in', no_default) + expires_in_str = assoc_response.getArg(OPENID_NS, 'expires_in', + no_default) try: expires_in = int(expires_in_str) except ValueError as why: - raise ProtocolError('Invalid expires_in field: %s' % (why,)) + raise ProtocolError('Invalid expires_in field: %s' % (why, )) # OpenID 1 has funny association session behaviour. if assoc_response.isOpenID1(): session_type = self._getOpenID1SessionType(assoc_response) else: - session_type = assoc_response.getArg( - OPENID2_NS, 'session_type', no_default) + session_type = assoc_response.getArg(OPENID2_NS, 'session_type', + no_default) # Session type mismatch if assoc_session.session_type != session_type: if (assoc_response.isOpenID1() and - session_type == 'no-encryption'): + session_type == 'no-encryption'): # In OpenID 1, any association request can result in a # 'no-encryption' association response. Setting # assoc_session to a new no-encryption session should @@ -1448,8 +1452,8 @@ class GenericConsumer(object): fmt = 'Malformed response for %s session: %s' raise ProtocolError(fmt % (assoc_session.session_type, why)) - return Association.fromExpiresIn( - expires_in, assoc_handle, secret, assoc_type) + return Association.fromExpiresIn(expires_in, assoc_handle, secret, + assoc_type) class AuthRequest(object): @@ -1585,10 +1589,10 @@ class AuthRequest(object): realm_key = 'realm' message.updateArgs(OPENID_NS, { - realm_key: realm, - 'mode': mode, - 'return_to': return_to, - }) + realm_key: realm, + 'mode': mode, + 'return_to': return_to, + }) if not self._anonymous: if self.endpoint.isOPIdentifier(): @@ -1608,12 +1612,12 @@ class AuthRequest(object): if self.assoc: message.setArg(OPENID_NS, 'assoc_handle', self.assoc.handle) - assoc_log_msg = 'with association %s' % (self.assoc.handle,) + assoc_log_msg = 'with association %s' % (self.assoc.handle, ) else: assoc_log_msg = 'using stateless mode.' logging.info("Generated %s request to %s %s" % - (mode, self.endpoint.server_url, assoc_log_msg)) + (mode, self.endpoint.server_url, assoc_log_msg)) return message @@ -1658,8 +1662,11 @@ class AuthRequest(object): message = self.getMessage(realm, return_to, immediate) return message.toURL(self.endpoint.server_url) - def formMarkup(self, realm, return_to=None, immediate=False, - form_tag_attrs=None): + def formMarkup(self, + realm, + return_to=None, + immediate=False, + form_tag_attrs=None): """Get html for a form to submit this request to the IDP. @param form_tag_attrs: Dictionary of attributes to be added to @@ -1669,11 +1676,13 @@ class AuthRequest(object): @type form_tag_attrs: {unicode: unicode} """ message = self.getMessage(realm, return_to, immediate) - return message.toFormMarkup(self.endpoint.server_url, - form_tag_attrs) + return message.toFormMarkup(self.endpoint.server_url, form_tag_attrs) - def htmlMarkup(self, realm, return_to=None, immediate=False, - form_tag_attrs=None): + def htmlMarkup(self, + realm, + return_to=None, + immediate=False, + form_tag_attrs=None): """Get an autosubmitting HTML page that submits this request to the IDP. This is just a wrapper for formMarkup. @@ -1681,10 +1690,8 @@ class AuthRequest(object): @returns: str """ - return oidutil.autoSubmitHTML(self.formMarkup(realm, - return_to, - immediate, - form_tag_attrs)) + return oidutil.autoSubmitHTML( + self.formMarkup(realm, return_to, immediate, form_tag_attrs)) def shouldSendRedirect(self): """Should this OpenID authentication request be sent as a HTTP @@ -1694,6 +1701,7 @@ class AuthRequest(object): """ return self.endpoint.compatibilityMode() + FAILURE = 'failure' SUCCESS = 'success' CANCEL = 'cancel' @@ -1797,8 +1805,8 @@ class SuccessResponse(Response): for key in msg_args.keys(): if not self.isSigned(ns_uri, key): logging.info( - "SuccessResponse.getSignedNS: (%s, %s) not signed." - % (ns_uri, key)) + "SuccessResponse.getSignedNS: (%s, %s) not signed." % + (ns_uri, key)) return None return msg_args @@ -1835,20 +1843,18 @@ class SuccessResponse(Response): return self.getSigned(OPENID_NS, 'return_to') def __eq__(self, other): - return ( - (self.endpoint == other.endpoint) and - (self.identity_url == other.identity_url) and - (self.message == other.message) and - (self.signed_fields == other.signed_fields) and - (self.status == other.status)) + return ((self.endpoint == other.endpoint) and + (self.identity_url == other.identity_url) and + (self.message == other.message) and + (self.signed_fields == other.signed_fields) and + (self.status == other.status)) def __ne__(self, other): return not (self == other) def __repr__(self): return '<%s.%s id=%r signed=%r>' % ( - self.__class__.__module__, - self.__class__.__name__, + self.__class__.__module__, self.__class__.__name__, self.identity_url, self.signed_fields) @@ -1867,17 +1873,16 @@ class FailureResponse(Response): status = FAILURE - def __init__(self, endpoint, message=None, contact=None, - reference=None): + def __init__(self, endpoint, message=None, contact=None, reference=None): self.setEndpoint(endpoint) self.message = message self.contact = contact self.reference = reference def __repr__(self): - return "<%s.%s id=%r message=%r>" % ( - self.__class__.__module__, self.__class__.__name__, - self.identity_url, self.message) + return "<%s.%s id=%r message=%r>" % (self.__class__.__module__, + self.__class__.__name__, + self.identity_url, self.message) class CancelResponse(Response): diff --git a/openid/consumer/discover.py b/openid/consumer/discover.py index 544081355dc295767ba4d47c919bb09e9d1e54df..07f4a949ba713b169caefe774425acb42ea251b4 100644 --- a/openid/consumer/discover.py +++ b/openid/consumer/discover.py @@ -11,7 +11,7 @@ __all__ = [ 'OPENID_IDP_2_0_TYPE', 'OpenIDServiceEndpoint', 'discover', - ] +] import urllib.parse import logging @@ -37,6 +37,7 @@ OPENID_1_0_TYPE = 'http://openid.net/signon/1.0' from openid.message import OPENID1_NS as OPENID_1_0_MESSAGE_NS from openid.message import OPENID2_NS as OPENID_2_0_MESSAGE_NS + class OpenIDServiceEndpoint(object): """Object representing an OpenID service endpoint. @@ -48,11 +49,10 @@ class OpenIDServiceEndpoint(object): # ordering of this list affects yadis and XRI service discovery. openid_type_uris = [ OPENID_IDP_2_0_TYPE, - OPENID_2_0_TYPE, OPENID_1_1_TYPE, OPENID_1_0_TYPE, - ] + ] def __init__(self): self.claimed_id = None @@ -60,7 +60,7 @@ class OpenIDServiceEndpoint(object): self.type_uris = [] self.local_id = None self.canonicalID = None - self.used_yadis = False # whether this came from an XRDS + self.used_yadis = False # whether this came from an XRDS self.display_identifier = None def usesExtension(self, extension_uri): @@ -68,7 +68,7 @@ class OpenIDServiceEndpoint(object): def preferredNamespace(self): if (OPENID_IDP_2_0_TYPE in self.type_uris or - OPENID_2_0_TYPE in self.type_uris): + OPENID_2_0_TYPE in self.type_uris): return OPENID_2_0_MESSAGE_NS else: return OPENID_1_0_MESSAGE_NS @@ -78,10 +78,8 @@ class OpenIDServiceEndpoint(object): I consider C{/server} endpoints to implicitly support C{/signon}. """ - return ( - (type_uri in self.type_uris) or - (type_uri == OPENID_2_0_TYPE and self.isOPIdentifier()) - ) + return ((type_uri in self.type_uris) or + (type_uri == OPENID_2_0_TYPE and self.isOPIdentifier())) def getDisplayIdentifier(self): """Return the display_identifier if set, else return the claimed_id. @@ -138,11 +136,9 @@ class OpenIDServiceEndpoint(object): # specified, then this is an OpenID endpoint if type_uris and endpoint.uri is not None: openid_endpoint = cls() - openid_endpoint.parseService( - endpoint.yadis_url, - endpoint.uri, - endpoint.type_uris, - endpoint.service_element) + openid_endpoint.parseService(endpoint.yadis_url, endpoint.uri, + endpoint.type_uris, + endpoint.service_element) else: openid_endpoint = None @@ -159,20 +155,20 @@ class OpenIDServiceEndpoint(object): discovery_types = [ (OPENID_2_0_TYPE, 'openid2.provider', 'openid2.local_id'), (OPENID_1_1_TYPE, 'openid.server', 'openid.delegate'), - ] + ] link_attrs = html_parse.parseLinkAttrs(html) services = [] for type_uri, op_endpoint_rel, local_id_rel in discovery_types: - op_endpoint_url = html_parse.findFirstHref( - link_attrs, op_endpoint_rel) + op_endpoint_url = html_parse.findFirstHref(link_attrs, + op_endpoint_rel) if op_endpoint_url is None: continue service = cls() service.claimed_id = uri - service.local_id = html_parse.findFirstHref( - link_attrs, local_id_rel) + service.local_id = html_parse.findFirstHref(link_attrs, + local_id_rel) service.server_url = op_endpoint_url service.type_uris = [type_uri] @@ -182,7 +178,6 @@ class OpenIDServiceEndpoint(object): fromHTML = classmethod(fromHTML) - def fromXRDS(cls, uri, xrds): """Parse the given document as XRDS looking for OpenID services. @@ -196,7 +191,6 @@ class OpenIDServiceEndpoint(object): fromXRDS = classmethod(fromXRDS) - def fromDiscoveryResult(cls, discoveryResult): """Create endpoints from a DiscoveryResult. @@ -217,7 +211,6 @@ class OpenIDServiceEndpoint(object): fromDiscoveryResult = classmethod(fromDiscoveryResult) - def fromOPEndpointURL(cls, op_endpoint_url): """Construct an OP-Identifier OpenIDServiceEndpoint object for a given OP Endpoint URL @@ -232,7 +225,6 @@ class OpenIDServiceEndpoint(object): fromOPEndpointURL = classmethod(fromOPEndpointURL) - def __str__(self): return ("<%s.%s " "server_url=%r " @@ -240,14 +232,9 @@ class OpenIDServiceEndpoint(object): "local_id=%r " "canonicalID=%r " "used_yadis=%s " - ">" - % (self.__class__.__module__, self.__class__.__name__, - self.server_url, - self.claimed_id, - self.local_id, - self.canonicalID, - self.used_yadis)) - + ">" % (self.__class__.__module__, self.__class__.__name__, + self.server_url, self.claimed_id, self.local_id, + self.canonicalID, self.used_yadis)) def findOPLocalIdentifier(service_element, type_uris): @@ -279,8 +266,7 @@ def findOPLocalIdentifier(service_element, type_uris): # Build the list of tags that could contain the OP-Local Identifier local_id_tags = [] - if (OPENID_1_1_TYPE in type_uris or - OPENID_1_0_TYPE in type_uris): + if (OPENID_1_1_TYPE in type_uris or OPENID_1_0_TYPE in type_uris): local_id_tags.append(nsTag(OPENID_1_0_NS, 'Delegate')) if OPENID_2_0_TYPE in type_uris: @@ -295,27 +281,30 @@ def findOPLocalIdentifier(service_element, type_uris): local_id = local_id_element.text elif local_id != local_id_element.text: format = 'More than one %r tag found in one service element' - message = format % (local_id_tag,) + message = format % (local_id_tag, ) raise DiscoveryFailure(message, None) return local_id + def normalizeURL(url): """Normalize a URL, converting normalization failures to DiscoveryFailure""" try: normalized = urinorm.urinorm(url) except ValueError as why: - raise DiscoveryFailure('Normalizing identifier: %s' % (why,), None) + raise DiscoveryFailure('Normalizing identifier: %s' % (why, ), None) else: return urllib.parse.urldefrag(normalized)[0] + def normalizeXRI(xri): """Normalize an XRI, stripping its scheme if present""" if xri.startswith("xri://"): xri = xri[6:] return xri + def arrangeByType(service_list, preferred_types): """Rearrange service_list in a new list so services are ordered by types listed in preferred_types. Return the new list.""" @@ -355,6 +344,7 @@ def arrangeByType(service_list, preferred_types): return prio_services + def getOPOrUserServices(openid_services): """Extract OP Identifier services. If none found, return the rest, sorted with most preferred first according to @@ -371,6 +361,7 @@ def getOPOrUserServices(openid_services): return op_services or openid_services + def discoverYadis(uri): """Discover OpenID services for a URI. Tries Yadis and falls back on old-style <link rel='...'> discovery if Yadis fails. @@ -412,6 +403,7 @@ def discoverYadis(uri): return (yadis_url, getOPOrUserServices(openid_services)) + def discoverXRI(iname): endpoints = [] iname = normalizeXRI(iname) @@ -420,7 +412,7 @@ def discoverXRI(iname): iname, OpenIDServiceEndpoint.openid_type_uris) if canonicalID is None: - raise XRDSError('No CanonicalID found for XRI %r' % (iname,)) + raise XRDSError('No CanonicalID found for XRI %r' % (iname, )) flt = filters.mkFilter(OpenIDServiceEndpoint) for service_element in services: @@ -444,13 +436,14 @@ def discoverNoYadis(uri): if http_resp.status not in (200, 206): raise DiscoveryFailure( 'HTTP Response status from identity URL host is not 200. ' - 'Got status %r' % (http_resp.status,), http_resp) + 'Got status %r' % (http_resp.status, ), http_resp) claimed_id = http_resp.final_url - openid_services = OpenIDServiceEndpoint.fromHTML( - claimed_id, http_resp.body) + openid_services = OpenIDServiceEndpoint.fromHTML(claimed_id, + http_resp.body) return claimed_id, openid_services + def discoverURI(uri): parsed = urllib.parse.urlparse(uri) if parsed[0] and parsed[1]: @@ -464,6 +457,7 @@ def discoverURI(uri): claimed_id = normalizeURL(claimed_id) return claimed_id, openid_services + def discover(identifier): if xri.identifierScheme(identifier) == "XRI": return discoverXRI(identifier) diff --git a/openid/consumer/html_parse.py b/openid/consumer/html_parse.py index 4c81949498d59571e9fc899abb6b31ea977f70af..cc154eb5b682aa08fe76463d73db3bd490a25576 100644 --- a/openid/consumer/html_parse.py +++ b/openid/consumer/html_parse.py @@ -71,11 +71,11 @@ __all__ = ['parseLinkAttrs'] import re -flags = (re.DOTALL # Match newlines with '.' - | re.IGNORECASE - | re.VERBOSE # Allow comments and whitespace in patterns - | re.UNICODE # Make \b respect Unicode word boundaries - ) +flags = ( + re.DOTALL # Match newlines with '.' + | re.IGNORECASE | re.VERBOSE # Allow comments and whitespace in patterns + | re.UNICODE # Make \b respect Unicode word boundaries +) # Stuff to remove before we start looking for tags removed_re = re.compile(r''' @@ -126,14 +126,15 @@ tag_expr = r''' def tagMatcher(tag_name, *close_tags): if close_tags: - options = '|'.join((tag_name,) + close_tags) - closers = '(?:%s)' % (options,) + options = '|'.join((tag_name, ) + close_tags) + closers = '(?:%s)' % (options, ) else: closers = tag_name expr = tag_expr % locals() return re.compile(expr, flags) + # Must contain at least an open html and an open head tag html_find = tagMatcher('html') head_find = tagMatcher('head', 'body') @@ -165,7 +166,7 @@ replacements = { 'lt': '<', 'gt': '>', 'quot': '"', - } +} ent_replace = re.compile(r'&(%s);' % '|'.join(list(replacements.keys()))) @@ -228,8 +229,8 @@ def parseLinkAttrs(html, ignore_errors=False): # Either q_val or unq_val must be present, but not both # unq_val is a True (non-empty) value if it is present - attr_name, q_val, unq_val = attr_mo.group( - 'attr_name', 'q_val', 'unq_val') + attr_name, q_val, unq_val = attr_mo.group('attr_name', 'q_val', + 'unq_val') attr_val = ent_replace.sub(replaceEnt, unq_val or q_val) link_attrs[attr_name] = attr_val diff --git a/openid/cryptutil.py b/openid/cryptutil.py index c68bb11813304faf0dc623006370b2a9a8bd7cad..7fdd628519b6bb676c2d6a709ed28a3d33f61acc 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -24,7 +24,7 @@ __all__ = [ 'randrange', 'sha1', 'sha256', - ] +] import hmac import os @@ -40,6 +40,7 @@ class HashContainer(object): self.new = hash_constructor self.digest_size = hash_constructor().digest_size + sha1_module = HashContainer(hashlib.sha1) sha256_module = HashContainer(hashlib.sha256) @@ -74,7 +75,6 @@ def sha256(s): SHA256_AVAILABLE = True - try: from Crypto.Util.number import long_to_bytes, bytes_to_long except ImportError: diff --git a/openid/extension.py b/openid/extension.py index 129d304c5490fd73624c113ed31c085e6ded8b89..0052f0033b88a9d6480cf9de767caa5ce67cb64d 100644 --- a/openid/extension.py +++ b/openid/extension.py @@ -34,14 +34,15 @@ class Extension(object): warnings.warn( 'Passing None to Extension.toMessage is deprecated. ' 'Creating a message assuming you want OpenID 2.', - DeprecationWarning, stacklevel=2) + DeprecationWarning, + stacklevel=2) message = message_module.Message(message_module.OPENID2_NS) implicit = message.isOpenID1() try: - message.namespaces.addAlias(self.ns_uri, self.ns_alias, - implicit=implicit) + message.namespaces.addAlias( + self.ns_uri, self.ns_alias, implicit=implicit) except KeyError: if message.namespaces.getAlias(self.ns_uri) != self.ns_alias: raise diff --git a/openid/extensions/ax.py b/openid/extensions/ax.py index 68d1c08f627e701f1903b843d3d75df363189da3..d36fd00872cf20be9a61470b09b5026a8f886dfa 100644 --- a/openid/extensions/ax.py +++ b/openid/extensions/ax.py @@ -10,7 +10,7 @@ __all__ = [ 'FetchResponse', 'StoreRequest', 'StoreResponse', - ] +] from openid import extension from openid.server.trustroot import TrustRoot @@ -31,9 +31,9 @@ def checkAlias(alias): found. Return None if the alias is valid. """ if ',' in alias: - raise AXError("Alias %r must not contain comma" % (alias,)) + raise AXError("Alias %r must not contain comma" % (alias, )) if '.' in alias: - raise AXError("Alias %r must not contain period" % (alias,)) + raise AXError("Alias %r must not contain period" % (alias, )) class AXError(ValueError): @@ -85,8 +85,7 @@ class AXMessage(extension.Extension): if not mode: raise NotAXMessage() else: - raise AXError( - 'Expected mode %r; got %r' % (self.mode, mode)) + raise AXError('Expected mode %r; got %r' % (self.mode, mode)) def _newArgs(self): """Return a set of attribute exchange arguments containing the @@ -175,8 +174,8 @@ def toTypeURIs(namespace_map, alias_list_s): for alias in alias_list_s.split(','): type_uri = namespace_map.getNamespaceURI(alias) if type_uri is None: - raise KeyError( - 'No type is defined for attribute name %r' % (alias,)) + raise KeyError('No type is defined for attribute name %r' % + (alias, )) else: uris.append(type_uri) @@ -215,8 +214,8 @@ class FetchRequest(AXMessage): present in this fetch request. """ if attribute.type_uri in self.requested_attributes: - raise KeyError('The attribute %r has already been requested' - % (attribute.type_uri,)) + raise KeyError('The attribute %r has already been requested' % + (attribute.type_uri, )) self.requested_attributes[attribute.type_uri] = attribute @@ -318,14 +317,15 @@ class FetchRequest(AXMessage): message.getArg(OPENID_NS, 'return_to')) if not realm: - raise AXError(("Cannot validate update_url %r " + - "against absent realm") % (self.update_url,)) + raise AXError( + ("Cannot validate update_url %r " + "against absent realm") + % (self.update_url, )) tr = TrustRoot.parse(realm) if not tr.validateURL(self.update_url): raise AXError( "Update URL %r failed validation against realm %r" % - (self.update_url, realm,)) + (self.update_url, realm, )) return self @@ -368,12 +368,11 @@ class FetchRequest(AXMessage): if count <= 0: raise AXError( "Count %r must be greater than zero, got %r" % - (count_key, count_s,)) + (count_key, count_s, )) except ValueError: if count_s != UNLIMITED_VALUES: - raise AXError( - "Invalid count value for %r: %r" % - (count_key, count_s,)) + raise AXError("Invalid count value for %r: %r" % + (count_key, count_s, )) count = count_s else: count = 1 @@ -391,9 +390,9 @@ class FetchRequest(AXMessage): for type_uri in aliases.iterNamespaceURIs(): if type_uri not in all_type_uris: - raise AXError( - 'Type URI %r was in the request but not ' - 'present in "required" or "if_available"' % (type_uri,)) + raise AXError('Type URI %r was in the request but not ' + 'present in "required" or "if_available"' % + (type_uri, )) self.update_url = ax_args.get('update_url') @@ -560,8 +559,7 @@ class AXKeyValueMessage(AXMessage): elif len(values) == 1: return values[0] else: - raise AXError( - 'More than one value present for %r' % (type_uri,)) + raise AXError('More than one value present for %r' % (type_uri, )) def get(self, type_uri): """Get the list of values for this attribute in the @@ -644,8 +642,8 @@ class FetchResponse(AXKeyValueMessage): for type_uri in self.data: if type_uri not in self.request: raise KeyError( - 'Response attribute not present in request: %r' - % (type_uri,)) + 'Response attribute not present in request: %r' % + (type_uri, )) for attr_info in self.request.iterAttrs(): # Copy the aliases from the request so that reading @@ -665,7 +663,7 @@ class FetchResponse(AXKeyValueMessage): (attr_info.count < len(values)): raise AXError( 'More than the number of requested values were ' - 'specified for %r' % (attr_info.type_uri,)) + 'specified for %r' % (attr_info.type_uri, )) kv_args = self._getExtensionKVArgs(aliases) @@ -680,8 +678,8 @@ class FetchResponse(AXKeyValueMessage): kv_args['type.' + alias] = attr_info.type_uri kv_args['count.' + alias] = '0' - update_url = ((self.request and self.request.update_url) - or self.update_url) + update_url = ((self.request and self.request.update_url) or + self.update_url) if update_url: ax_args['update_url'] = update_url @@ -762,7 +760,7 @@ class StoreResponse(AXMessage): if succeeded and error_message is not None: raise AXError('An error message may only be included in a ' - 'failing fetch response') + 'failing fetch response') if succeeded: self.mode = self.SUCCESS_MODE else: diff --git a/openid/extensions/draft/pape2.py b/openid/extensions/draft/pape2.py index 41824ca4c04b4833e4a18f984b4a4952a88556d4..d0587e316080de8846e7a2f80075b995395a878a 100644 --- a/openid/extensions/draft/pape2.py +++ b/openid/extensions/draft/pape2.py @@ -13,7 +13,7 @@ __all__ = [ 'AUTH_PHISHING_RESISTANT', 'AUTH_MULTI_FACTOR', 'AUTH_MULTI_FACTOR_PHYSICAL', - ] +] from openid.extension import Extension import re @@ -75,8 +75,8 @@ class Request(Extension): """@see: C{L{Extension.getExtensionArgs}} """ ns_args = { - 'preferred_auth_policies':' '.join(self.preferred_auth_policies) - } + 'preferred_auth_policies': ' '.join(self.preferred_auth_policies) + } if self.max_auth_age is not None: ns_args['max_auth_age'] = str(self.max_auth_age) @@ -147,8 +147,9 @@ class Request(Extension): @returntype: [str] """ - return list(filter(self.preferred_auth_policies.__contains__, - supported_types)) + return list( + filter(self.preferred_auth_policies.__contains__, supported_types)) + Request.ns_uri = ns_uri @@ -160,7 +161,9 @@ class Response(Extension): ns_alias = 'pape' - def __init__(self, auth_policies=None, auth_time=None, + def __init__(self, + auth_policies=None, + auth_time=None, nist_auth_level=None): super(Response, self).__init__() if auth_policies: @@ -262,7 +265,7 @@ class Response(Extension): else: ns_args = { 'auth_policies': ' '.join(self.auth_policies), - } + } if self.nist_auth_level is not None: if self.nist_auth_level not in list(range(0, 5)): @@ -278,4 +281,5 @@ class Response(Extension): return ns_args + Response.ns_uri = ns_uri diff --git a/openid/extensions/draft/pape5.py b/openid/extensions/draft/pape5.py index 3e376ce3d6ea8c1d3e44694bab7ec0f569226f4e..1441dd3acde7fb2bd08a690e1c821d4318defc6b 100644 --- a/openid/extensions/draft/pape5.py +++ b/openid/extensions/draft/pape5.py @@ -15,7 +15,7 @@ __all__ = [ 'AUTH_MULTI_FACTOR_PHYSICAL', 'LEVELS_NIST', 'LEVELS_JISA', - ] +] from openid.extension import Extension import warnings @@ -42,7 +42,7 @@ class PAPEExtension(Extension): _default_auth_level_aliases = { 'nist': LEVELS_NIST, 'jisa': LEVELS_JISA, - } + } def __init__(self): self.auth_level_aliases = self._default_auth_level_aliases.copy() @@ -73,7 +73,7 @@ class PAPEExtension(Extension): def _generateAlias(self): """Return an unused auth level alias""" for i in range(1000): - alias = 'cust%d' % (i,) + alias = 'cust%d' % (i, ) if alias not in self.auth_level_aliases: return alias @@ -112,7 +112,9 @@ class Request(PAPEExtension): ns_alias = 'pape' - def __init__(self, preferred_auth_policies=None, max_auth_age=None, + def __init__(self, + preferred_auth_policies=None, + max_auth_age=None, preferred_auth_level_types=None): super(Request, self).__init__() if preferred_auth_policies is None: @@ -153,8 +155,8 @@ class Request(PAPEExtension): """@see: C{L{Extension.getExtensionArgs}} """ ns_args = { - 'preferred_auth_policies':' '.join(self.preferred_auth_policies), - } + 'preferred_auth_policies': ' '.join(self.preferred_auth_policies), + } if self.max_auth_age is not None: ns_args['max_auth_age'] = str(self.max_auth_age) @@ -164,7 +166,7 @@ class Request(PAPEExtension): for auth_level_uri in self.preferred_auth_level_types: alias = self._getAlias(auth_level_uri) - ns_args['auth_level.ns.%s' % (alias,)] = auth_level_uri + ns_args['auth_level.ns.%s' % (alias, )] = auth_level_uri preferred_types.append(alias) ns_args['preferred_auth_level_types'] = ' '.join(preferred_types) @@ -234,7 +236,7 @@ class Request(PAPEExtension): aliases = preferred_auth_level_types.strip().split() for alias in aliases: - key = 'auth_level.ns.%s' % (alias,) + key = 'auth_level.ns.%s' % (alias, ) try: uri = args[key] except KeyError: @@ -246,7 +248,7 @@ class Request(PAPEExtension): if uri is None: if strict: raise ValueError('preferred auth level %r is not ' - 'defined in this message' % (alias,)) + 'defined in this message' % (alias, )) else: self.addAuthLevel(uri, alias) @@ -266,8 +268,9 @@ class Request(PAPEExtension): @returntype: [str] """ - return list(filter(self.preferred_auth_policies.__contains__, - supported_types)) + return list( + filter(self.preferred_auth_policies.__contains__, supported_types)) + Request.ns_uri = ns_uri @@ -282,8 +285,7 @@ class Response(PAPEExtension): ns_alias = 'pape' - def __init__(self, auth_policies=None, auth_time=None, - auth_levels=None): + def __init__(self, auth_policies=None, auth_time=None, auth_levels=None): super(Response, self).__init__() if auth_policies: self.auth_policies = auth_policies @@ -403,7 +405,7 @@ class Response(PAPEExtension): if (len(auth_policies) > 1 and strict and AUTH_NONE in auth_policies): raise ValueError('Got some auth policies, as well as the special ' - '"none" URI: %r' % (auth_policies,)) + '"none" URI: %r' % (auth_policies, )) if 'none' in auth_policies: msg = '"none" used as a policy URI (see PAPE draft < 5)' @@ -412,8 +414,9 @@ class Response(PAPEExtension): else: warnings.warn(msg, stacklevel=2) - auth_policies = [u for u in auth_policies - if u not in ['none', AUTH_NONE]] + auth_policies = [ + u for u in auth_policies if u not in ['none', AUTH_NONE] + ] self.auth_policies = auth_policies @@ -426,7 +429,7 @@ class Response(PAPEExtension): continue try: - uri = args['auth_level.ns.%s' % (alias,)] + uri = args['auth_level.ns.%s' % (alias, )] except KeyError: if is_openid1: uri = self._default_auth_level_aliases.get(alias) @@ -435,8 +438,8 @@ class Response(PAPEExtension): if uri is None: if strict: - raise ValueError( - 'Undefined auth level alias: %r' % (alias,)) + raise ValueError('Undefined auth level alias: %r' % + (alias, )) else: self.setAuthLevel(uri, val, alias) @@ -458,13 +461,13 @@ class Response(PAPEExtension): } else: ns_args = { - 'auth_policies':' '.join(self.auth_policies), - } + 'auth_policies': ' '.join(self.auth_policies), + } for level_type, level in self.auth_levels.items(): alias = self._getAlias(level_type) - ns_args['auth_level.ns.%s' % (alias,)] = level_type - ns_args['auth_level.%s' % (alias,)] = str(level) + ns_args['auth_level.ns.%s' % (alias, )] = level_type + ns_args['auth_level.%s' % (alias, )] = str(level) if self.auth_time is not None: if not TIME_VALIDATOR.match(self.auth_time): @@ -474,4 +477,5 @@ class Response(PAPEExtension): return ns_args + Response.ns_uri = ns_uri diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index 201f4f04ba5d07cfc6ccd26d7771511417130484..c0b2090c85e8a63774031331b194b97d4805a8e3 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -41,10 +41,10 @@ from openid.extension import Extension import logging try: - str #pylint:disable-msg=W0104 + str #pylint:disable-msg=W0104 except NameError: # For Python 2.2 - str = (str, str) #pylint:disable-msg=W0622 + str = (str, str) #pylint:disable-msg=W0622 __all__ = [ 'SRegRequest', @@ -54,20 +54,21 @@ __all__ = [ 'ns_uri_1_0', 'ns_uri_1_1', 'supportsSReg', - ] +] # The data fields that are listed in the sreg spec data_fields = { - 'fullname':'Full Name', - 'nickname':'Nickname', - 'dob':'Date of Birth', - 'email':'E-mail Address', - 'gender':'Gender', - 'postcode':'Postal Code', - 'country':'Country', - 'language':'Language', - 'timezone':'Time Zone', - } + 'fullname': 'Full Name', + 'nickname': 'Nickname', + 'dob': 'Date of Birth', + 'email': 'E-mail Address', + 'gender': 'Gender', + 'postcode': 'Postal Code', + 'country': 'Country', + 'language': 'Language', + 'timezone': 'Time Zone', +} + def checkFieldName(field_name): """Check to see that the given value is a valid simple @@ -78,7 +79,8 @@ def checkFieldName(field_name): """ if field_name not in data_fields: raise ValueError('%r is not a defined simple registration field' % - (field_name,)) + (field_name, )) + # URI used in the wild for Yadis documents advertising simple # registration support @@ -95,8 +97,9 @@ ns_uri = ns_uri_1_1 try: registerNamespaceAlias(ns_uri_1_1, 'sreg') except NamespaceAliasRegistrationError as e: - logging.exception('registerNamespaceAlias(%r, %r) failed: %s' % (ns_uri_1_1, - 'sreg', str(e),)) + logging.exception('registerNamespaceAlias(%r, %r) failed: %s' % + (ns_uri_1_1, 'sreg', str(e), )) + def supportsSReg(endpoint): """Does the given endpoint advertise support for simple @@ -111,6 +114,7 @@ def supportsSReg(endpoint): return (endpoint.usesExtension(ns_uri_1_1) or endpoint.usesExtension(ns_uri_1_0)) + class SRegNamespaceError(ValueError): """The simple registration namespace was not found and could not be created using the expected name (there's another extension @@ -125,6 +129,7 @@ class SRegNamespaceError(ValueError): the message that is being processed. """ + def getSRegNS(message): """Extract the simple registration namespace URI from the given OpenID message. Handles OpenID 1 and 2, as well as both sreg @@ -163,7 +168,8 @@ def getSRegNS(message): # we know that sreg_ns_uri defined, because it's defined in the # else clause of the loop as well, so disable the warning - return sreg_ns_uri #pylint:disable-msg=W0631 + return sreg_ns_uri #pylint:disable-msg=W0631 + class SRegRequest(Extension): """An object to hold the state of a simple registration request. @@ -185,7 +191,10 @@ class SRegRequest(Extension): ns_alias = 'sreg' - def __init__(self, required=None, optional=None, policy_url=None, + def __init__(self, + required=None, + optional=None, + policy_url=None, sreg_ns_uri=ns_uri): """Initialize an empty simple registration request""" Extension.__init__(self) @@ -287,8 +296,7 @@ class SRegRequest(Extension): def __contains__(self, field_name): """Was this field in the request?""" - return (field_name in self.required or - field_name in self.optional) + return (field_name in self.required or field_name in self.optional) def requestField(self, field_name, required=False, strict=False): """Request the specified field from the OpenID user @@ -345,7 +353,7 @@ class SRegRequest(Extension): """ if isinstance(field_names, str): raise TypeError('Fields should be passed as a list of ' - 'strings (not %r)' % (type(field_names),)) + 'strings (not %r)' % (type(field_names), )) for field_name in field_names: self.requestField(field_name, required, strict=strict) @@ -373,6 +381,7 @@ class SRegRequest(Extension): return args + class SRegResponse(Extension): """Represents the data returned in a simple registration response inside of an OpenID C{id_res} response. This object will be diff --git a/openid/fetchers.py b/openid/fetchers.py index 68ff06976f5032c0b6d13da5e5469211ed73a98d..c109cd2fd026f8414ae5d53012232891f107e2a6 100644 --- a/openid/fetchers.py +++ b/openid/fetchers.py @@ -3,9 +3,10 @@ This module contains the HTTP fetcher interface and several implementations. """ -__all__ = ['fetch', 'getDefaultFetcher', 'setDefaultFetcher', 'HTTPResponse', - 'HTTPFetcher', 'createHTTPFetcher', 'HTTPFetchingError', - 'HTTPError'] +__all__ = [ + 'fetch', 'getDefaultFetcher', 'setDefaultFetcher', 'HTTPResponse', + 'HTTPFetcher', 'createHTTPFetcher', 'HTTPFetchingError', 'HTTPError' +] import urllib.request import urllib.error @@ -59,6 +60,7 @@ def createHTTPFetcher(): return fetcher + # Contains the currently set HTTP fetcher. If it is set to None, the # library will call createHTTPFetcher() to set it. Do not access this # variable outside of this module. @@ -123,8 +125,7 @@ class HTTPResponse(object): self.body = body def __repr__(self): - return "<%s status %s for %s>" % (self.__class__.__name__, - self.status, + return "<%s status %s for %s>" % (self.__class__.__name__, self.status, self.final_url) @@ -170,6 +171,7 @@ class HTTPFetchingError(Exception): @ivar why: The exception that caused this exception """ + def __init__(self, why=None): Exception.__init__(self, why) self.why = why @@ -211,14 +213,13 @@ class Urllib2Fetcher(HTTPFetcher): def fetch(self, url, body=None, headers=None): if not _allowedURL(url): - raise ValueError('Bad URL scheme: %r' % (url,)) + raise ValueError('Bad URL scheme: %r' % (url, )) if headers is None: headers = {} - headers.setdefault( - 'User-Agent', - "%s Python-urllib/%s" % (USER_AGENT, urllib.request.__version__)) + headers.setdefault('User-Agent', "%s Python-urllib/%s" % + (USER_AGENT, urllib.request.__version__)) if isinstance(body, str): body = bytes(body, encoding="utf-8") @@ -321,26 +322,16 @@ class CurlHTTPFetcher(HTTPFetcher): def _parseHeaders(self, header_file): header_file.seek(0) - # Remove the status line from the beginning of the input - unused_http_status_line = header_file.readline().lower() - if unused_http_status_line.startswith(b'http/1.1 100 '): - unused_http_status_line = header_file.readline() - unused_http_status_line = header_file.readline() - - lines = [line.decode().strip() for line in header_file] - - # and the blank line from the end - empty_line = lines.pop() - if empty_line: - raise HTTPError("No blank line at end of headers: %r" % (line,)) + # Remove all non "name: value" header lines from the input + lines = [line.decode().strip() for line in header_file if b':' in line] headers = {} for line in lines: try: name, value = line.split(':', 1) except ValueError: - raise HTTPError( - "Malformed HTTP header line in response: %r" % (line,)) + raise HTTPError("Malformed HTTP header line in response: %r" % + (line, )) value = value.strip() @@ -363,7 +354,7 @@ class CurlHTTPFetcher(HTTPFetcher): headers = {} headers.setdefault('User-Agent', - "%s %s" % (USER_AGENT, pycurl.version,)) + "%s %s" % (USER_AGENT, pycurl.version, )) header_list = [] if headers is not None: @@ -385,7 +376,7 @@ class CurlHTTPFetcher(HTTPFetcher): while off > 0: if not self._checkURL(url): - raise HTTPError("Fetching URL not allowed: %r" % (url,)) + raise HTTPError("Fetching URL not allowed: %r" % (url, )) data = io.BytesIO() @@ -426,7 +417,7 @@ class CurlHTTPFetcher(HTTPFetcher): off = stop - int(time.time()) - raise HTTPError("Timed out fetching: %r" % (url,)) + raise HTTPError("Timed out fetching: %r" % (url, )) finally: c.close() @@ -474,7 +465,7 @@ class HTTPLib2Fetcher(HTTPFetcher): # httplib2 doesn't check to make sure that the URL's scheme is # 'http' so we do it here. if not (url.startswith('http://') or url.startswith('https://')): - raise ValueError('URL is not a HTTP URL: %r' % (url,)) + raise ValueError('URL is not a HTTP URL: %r' % (url, )) httplib2_response, content = self.httplib2.request( url, method, body=body, headers=headers) @@ -496,8 +487,7 @@ class HTTPLib2Fetcher(HTTPFetcher): final_url = url return HTTPResponse( - body=content.decode(), # TODO Don't assume ASCII + body=content.decode(), # TODO Don't assume ASCII final_url=final_url, headers=dict(list(httplib2_response.items())), - status=httplib2_response.status, - ) + status=httplib2_response.status, ) diff --git a/openid/kvform.py b/openid/kvform.py index 5a210a07ffebe62db773f44eedd0792049fddaeb..571bd9e9596ebf2c823ececa8f90099450edb918 100644 --- a/openid/kvform.py +++ b/openid/kvform.py @@ -17,6 +17,7 @@ def seqToKV(seq, strict=False): @return: A string representation of the sequence @rtype: bytes """ + def err(msg): formatted = 'seqToKV warning: %s: %r' % (msg, seq) if strict: @@ -34,27 +35,28 @@ def seqToKV(seq, strict=False): if '\n' in k: raise KVFormError( - 'Invalid input for seqToKV: key contains newline: %r' % (k,)) + 'Invalid input for seqToKV: key contains newline: %r' % (k, )) if ':' in k: raise KVFormError( - 'Invalid input for seqToKV: key contains colon: %r' % (k,)) + 'Invalid input for seqToKV: key contains colon: %r' % (k, )) if k.strip() != k: - err('Key has whitespace at beginning or end: %r' % (k,)) + err('Key has whitespace at beginning or end: %r' % (k, )) if isinstance(v, bytes): v = v.decode('utf-8') elif not isinstance(v, str): - err('Converting value to string: %r' % (v,)) + err('Converting value to string: %r' % (v, )) v = str(v) if '\n' in v: raise KVFormError( - 'Invalid input for seqToKV: value contains newline: %r' % (v,)) + 'Invalid input for seqToKV: value contains newline: %r' % + (v, )) if v.strip() != v: - err('Value has whitespace at beginning or end: %r' % (v,)) + err('Value has whitespace at beginning or end: %r' % (v, )) lines.append(k + ':' + v + '\n') @@ -71,6 +73,7 @@ def kvToSeq(data, strict=False): @return str """ + def err(msg): formatted = 'kvToSeq warning: %s: %r' % (msg, data) if strict: @@ -106,7 +109,7 @@ def kvToSeq(data, strict=False): err(fmt % (line_num, k)) if not k_s: - err('In line %d, got empty key' % (line_num,)) + err('In line %d, got empty key' % (line_num, )) v_s = v.strip() if v_s != v: diff --git a/openid/message.py b/openid/message.py index 68735fa8e97412fe859688dcc4ec33b4136cc082..325c14d414e3244df57909376d8f5196d65ac2c1 100644 --- a/openid/message.py +++ b/openid/message.py @@ -1,8 +1,10 @@ """Extension argument processing code """ -__all__ = ['Message', 'NamespaceMap', 'no_default', 'registerNamespaceAlias', - 'OPENID_NS', 'BARE_NS', 'OPENID1_NS', 'OPENID2_NS', 'SREG_URI', - 'IDENTIFIER_SELECT'] +__all__ = [ + 'Message', 'NamespaceMap', 'no_default', 'registerNamespaceAlias', + 'OPENID_NS', 'BARE_NS', 'OPENID1_NS', 'OPENID2_NS', 'SREG_URI', + 'IDENTIFIER_SELECT' +] import copy import warnings @@ -51,12 +53,29 @@ OPENID1_URL_LIMIT = 2047 # All OpenID protocol fields. Used to check namespace aliases. OPENID_PROTOCOL_FIELDS = [ - 'ns', 'mode', 'error', 'return_to', 'contact', 'reference', - 'signed', 'assoc_type', 'session_type', 'dh_modulus', 'dh_gen', - 'dh_consumer_public', 'claimed_id', 'identity', 'realm', - 'invalidate_handle', 'op_endpoint', 'response_nonce', 'sig', - 'assoc_handle', 'trust_root', 'openid', - ] + 'ns', + 'mode', + 'error', + 'return_to', + 'contact', + 'reference', + 'signed', + 'assoc_type', + 'session_type', + 'dh_modulus', + 'dh_gen', + 'dh_consumer_public', + 'claimed_id', + 'identity', + 'realm', + 'invalidate_handle', + 'op_endpoint', + 'response_nonce', + 'sig', + 'assoc_handle', + 'trust_root', + 'openid', +] class UndefinedOpenIDNamespace(ValueError): @@ -69,10 +88,11 @@ class InvalidOpenIDNamespace(ValueError): For recognized values, see L{Message.allowed_openid_namespaces} """ + def __str__(self): s = "Invalid OpenID Namespace" if self.args: - s += " %r" % (self.args[0],) + s += " %r" % (self.args[0], ) return s @@ -107,11 +127,11 @@ def registerNamespaceAlias(namespace_uri, alias): if namespace_uri in list(registered_aliases.values()): raise NamespaceAliasRegistrationError( - 'Namespace uri %r already registered' % (namespace_uri,)) + 'Namespace uri %r already registered' % (namespace_uri, )) if alias in registered_aliases: - raise NamespaceAliasRegistrationError( - 'Alias %r already registered' % (alias,)) + raise NamespaceAliasRegistrationError('Alias %r already registered' % + (alias, )) registered_aliases[alias] = namespace_uri @@ -157,9 +177,8 @@ class Message(object): openid_args = {} for key, value in args.items(): if isinstance(value, list): - raise TypeError( - "query dict must have one value for each key, " - "not lists of values. Query is %r" % (args,)) + raise TypeError("query dict must have one value for each key, " + "not lists of values. Query is %r" % (args, )) try: prefix, rest = key.split('.', 1) @@ -306,7 +325,9 @@ class Message(object): return kvargs - def toFormMarkup(self, action_url, form_tag_attrs=None, + def toFormMarkup(self, + action_url, + form_tag_attrs=None, submit_text="Continue"): """Generate HTML form markup that contains the values in this message, to be HTTP POSTed as x-www-form-urlencoded UTF-8. @@ -345,14 +366,17 @@ class Message(object): form.attrib['enctype'] = 'application/x-www-form-urlencoded' for name, value in self.toPostArgs().items(): - attrs = {'type': 'hidden', - 'name': oidutil.toUnicode(name), - 'value': oidutil.toUnicode(value)} + attrs = { + 'type': 'hidden', + 'name': oidutil.toUnicode(name), + 'value': oidutil.toUnicode(value) + } form.append(ElementTree.Element('input', attrs)) submit = ElementTree.Element( 'input', - {'type': 'submit', 'value': oidutil.toUnicode(submit_text)}) + {'type': 'submit', + 'value': oidutil.toUnicode(submit_text)}) form.append(submit) return str(ElementTree.tostring(form, encoding='utf-8'), @@ -393,16 +417,18 @@ class Message(object): if namespace != BARE_NS and not isinstance(namespace, str): raise TypeError( - "Namespace must be BARE_NS, OPENID_NS or a string. got %r" - % (namespace,)) + "Namespace must be BARE_NS, OPENID_NS or a string. got %r" % + (namespace, )) if namespace != BARE_NS and ':' not in namespace: fmt = 'OpenID 2.0 namespace identifiers SHOULD be URIs. Got %r' - warnings.warn(fmt % (namespace,), DeprecationWarning) + warnings.warn(fmt % (namespace, ), DeprecationWarning) if namespace == 'sreg': fmt = 'Using %r instead of "sreg" as namespace' - warnings.warn(fmt % (SREG_URI,), DeprecationWarning,) + warnings.warn( + fmt % (SREG_URI, ), + DeprecationWarning, ) return SREG_URI return namespace @@ -508,8 +534,7 @@ class Message(object): def __repr__(self): return "<%s.%s %r>" % (self.__class__.__module__, - self.__class__.__name__, - self.args) + self.__class__.__name__, self.args) def __eq__(self, other): return self.args == other.args @@ -549,6 +574,7 @@ class Message(object): class NamespaceMap(object): """Maintains a bijective map between namespace uris and aliases. """ + def __init__(self): self.alias_to_namespace = {} self.namespace_to_alias = {} @@ -594,17 +620,14 @@ class NamespaceMap(object): # Check that there is not a namespace already defined for # the desired alias current_namespace_uri = self.alias_to_namespace.get(desired_alias) - if (current_namespace_uri is not None - and current_namespace_uri != namespace_uri): + if (current_namespace_uri is not None and + current_namespace_uri != namespace_uri): fmt = ('Cannot map %r to alias %r. ' '%r is already mapped to alias %r') - msg = fmt % ( - namespace_uri, - desired_alias, - current_namespace_uri, - desired_alias) + msg = fmt % (namespace_uri, desired_alias, current_namespace_uri, + desired_alias) raise KeyError(msg) # Check that there is not already a (different) alias for diff --git a/openid/oidutil.py b/openid/oidutil.py index 70637c44e21d019530a36861e09a1b1e90d69e62..754277f7cf43558f7454a032cc2fee73d3bfa743 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -5,8 +5,10 @@ For users of this library, the C{L{log}} function is probably the most interesting. """ -__all__ = ['log', 'appendArgs', 'toBase64', 'fromBase64', 'autoSubmitHTML', - 'toUnicode'] +__all__ = [ + 'log', 'appendArgs', 'toBase64', 'fromBase64', 'autoSubmitHTML', + 'toUnicode' +] import binascii import logging @@ -14,19 +16,17 @@ import logging # import urllib.parse as urlparse from urllib.parse import urlencode - xxe_safe_elementtree_modules = [ 'defusedxml.cElementTree', 'defusedxml.ElementTree', - ] - +] elementtree_modules = [ 'xml.etree.cElementTree', 'xml.etree.ElementTree', 'cElementTree', 'elementtree.ElementTree', - ] +] def toUnicode(value): @@ -83,8 +83,8 @@ def importSafeElementTree(module_names=None): return importElementTree(module_names) except ImportError: raise ImportError('Unable to find a ElementTree module ' - 'that is not vulnerable to XXE. ' - 'Tried importing %r' % (module_names,)) + 'that is not vulnerable to XXE. ' + 'Tried importing %r' % (module_names, )) def importElementTree(module_names=None): @@ -121,8 +121,7 @@ def importElementTree(module_names=None): else: raise ImportError('No ElementTree library found. ' 'You may need to install one. ' - 'Tried importing %r' % (module_names,) - ) + 'Tried importing %r' % (module_names, )) def log(message, level=0): @@ -144,7 +143,7 @@ def log(message, level=0): """ logging.error("This is a legacy log message, please use the " - "logging module. Message: %s", message) + "logging module. Message: %s", message) def appendArgs(url, args): @@ -234,4 +233,4 @@ class Symbol(object): return hash((self.__class__, self.name)) def __repr__(self): - return '<Symbol %s>' % (self.name,) + return '<Symbol %s>' % (self.name, ) diff --git a/openid/server/server.py b/openid/server/server.py index 53167abe4ac1a8be5d00eb38d2f1abb08b098963..a9e7325b43ba63409b7c38fe632d50ed55d7fd2d 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -138,9 +138,9 @@ HTTP_ERROR = 400 BROWSER_REQUEST_MODES = ['checkid_setup', 'checkid_immediate'] -ENCODE_KVFORM = ('kvform',) -ENCODE_URL = ('URL/redirect',) -ENCODE_HTML_FORM = ('HTML form',) +ENCODE_KVFORM = ('kvform', ) +ENCODE_URL = ('URL/redirect', ) +ENCODE_HTML_FORM = ('HTML form', ) UNUSED = None @@ -206,11 +206,9 @@ class CheckAuthRequest(OpenIDRequest): self.assoc_handle = message.getArg(OPENID_NS, 'assoc_handle') self.sig = message.getArg(OPENID_NS, 'sig') - if (self.assoc_handle is None or - self.sig is None): + if (self.assoc_handle is None or self.sig is None): fmt = "%s request missing required parameter from message %s" - raise ProtocolError( - message, text=fmt % (self.mode, message)) + raise ProtocolError(message, text=fmt % (self.mode, message)) self.invalidate_handle = message.getArg(OPENID_NS, 'invalidate_handle') @@ -246,21 +244,21 @@ class CheckAuthRequest(OpenIDRequest): response.fields.setArg(OPENID_NS, 'is_valid', valid_str) if self.invalidate_handle: - assoc = signatory.getAssociation(self.invalidate_handle, - dumb=False) + assoc = signatory.getAssociation( + self.invalidate_handle, dumb=False) if not assoc: - response.fields.setArg( - OPENID_NS, 'invalidate_handle', self.invalidate_handle) + response.fields.setArg(OPENID_NS, 'invalidate_handle', + self.invalidate_handle) return response def __str__(self): if self.invalidate_handle: - ih = " invalidate? %r" % (self.invalidate_handle,) + ih = " invalidate? %r" % (self.invalidate_handle, ) else: ih = "" s = "<%s handle: %r sig: %r: signed: %r%s>" % ( - self.__class__.__name__, self.assoc_handle, - self.sig, self.signed, ih) + self.__class__.__name__, self.assoc_handle, self.sig, self.signed, + ih) return s @@ -328,18 +326,17 @@ class DiffieHellmanSHA1ServerSession(object): """ dh_modulus = message.getArg(OPENID_NS, 'dh_modulus') dh_gen = message.getArg(OPENID_NS, 'dh_gen') - if (dh_modulus is None and dh_gen is not None or - dh_gen is None and dh_modulus is not None): + if (dh_modulus is None and dh_gen is not None or dh_gen is None and + dh_modulus is not None): if dh_modulus is None: missing = 'modulus' else: missing = 'generator' - raise ProtocolError(message, - 'If non-default modulus or generator is ' - 'supplied, both must be supplied. Missing %s' - % (missing,)) + raise ProtocolError( + message, 'If non-default modulus or generator is ' + 'supplied, both must be supplied. Missing %s' % (missing, )) if dh_modulus or dh_gen: dh_modulus = cryptutil.base64ToLong(dh_modulus) @@ -351,7 +348,7 @@ class DiffieHellmanSHA1ServerSession(object): consumer_pubkey = message.getArg(OPENID_NS, 'dh_consumer_public') if consumer_pubkey is None: raise ProtocolError(message, "Public key for DH-SHA1 session " - "not found in message %s" % (message,)) + "not found in message %s" % (message, )) consumer_pubkey = cryptutil.base64ToLong(consumer_pubkey) @@ -360,13 +357,12 @@ class DiffieHellmanSHA1ServerSession(object): fromMessage = classmethod(fromMessage) def answer(self, secret): - mac_key = self.dh.xorSecret(self.consumer_pubkey, - secret, + mac_key = self.dh.xorSecret(self.consumer_pubkey, secret, self.hash_func) return { 'dh_server_public': cryptutil.longToBase64(self.dh.public), 'enc_mac_key': oidutil.toBase64(mac_key), - } + } class DiffieHellmanSHA256ServerSession(DiffieHellmanSHA1ServerSession): @@ -398,7 +394,7 @@ class AssociateRequest(OpenIDRequest): 'no-encryption': PlainTextServerSession, 'DH-SHA1': DiffieHellmanSHA1ServerSession, 'DH-SHA256': DiffieHellmanSHA256ServerSession, - } + } def __init__(self, session, assoc_type): """Construct me. @@ -430,14 +426,14 @@ class AssociateRequest(OpenIDRequest): else: session_type = message.getArg(OPENID2_NS, 'session_type') if session_type is None: - raise ProtocolError(message, - text="session_type missing from request") + raise ProtocolError( + message, text="session_type missing from request") try: session_class = klass.session_classes[session_type] except KeyError: raise ProtocolError(message, - "Unknown session type %r" % (session_type,)) + "Unknown session type %r" % (session_type, )) try: session = session_class.fromMessage(message) @@ -472,7 +468,7 @@ class AssociateRequest(OpenIDRequest): 'expires_in': str(assoc.expiresIn), 'assoc_type': self.assoc_type, 'assoc_handle': assoc.handle, - }) + }) response.fields.updateArgs(OPENID_NS, self.session.answer(assoc.secret)) @@ -480,12 +476,14 @@ class AssociateRequest(OpenIDRequest): self.message.isOpenID1()): # The session type "no-encryption" did not have a name # in OpenID v1, it was just omitted. - response.fields.setArg( - OPENID_NS, 'session_type', self.session.session_type) + response.fields.setArg(OPENID_NS, 'session_type', + self.session.session_type) return response - def answerUnsupported(self, message, preferred_association_type=None, + def answerUnsupported(self, + message, + preferred_association_type=None, preferred_session_type=None): """Respond to this request indicating that the association type or association session type is not supported.""" @@ -497,12 +495,12 @@ class AssociateRequest(OpenIDRequest): response.fields.setArg(OPENID_NS, 'error', message) if preferred_association_type: - response.fields.setArg( - OPENID_NS, 'assoc_type', preferred_association_type) + response.fields.setArg(OPENID_NS, 'assoc_type', + preferred_association_type) if preferred_session_type: - response.fields.setArg( - OPENID_NS, 'session_type', preferred_session_type) + response.fields.setArg(OPENID_NS, 'session_type', + preferred_session_type) return response @@ -542,8 +540,14 @@ class CheckIDRequest(OpenIDRequest): @type assoc_handle: str """ - def __init__(self, identity, return_to, trust_root=None, immediate=False, - assoc_handle=None, op_endpoint=None, claimed_id=None): + def __init__(self, + identity, + return_to, + trust_root=None, + immediate=False, + assoc_handle=None, + op_endpoint=None, + claimed_id=None): """Construct me. These parameters are assigned directly as class attributes, see @@ -573,9 +577,12 @@ class CheckIDRequest(OpenIDRequest): self.message = None def _getNamespace(self): - warnings.warn('The "namespace" attribute of CheckIDRequest objects ' - 'is deprecated. Use "message.getOpenIDNamespace()" ' - 'instead', DeprecationWarning, stacklevel=2) + warnings.warn( + 'The "namespace" attribute of CheckIDRequest objects ' + 'is deprecated. Use "message.getOpenIDNamespace()" ' + 'instead', + DeprecationWarning, + stacklevel=2) return self.message.getOpenIDNamespace() namespace = property(_getNamespace) @@ -614,7 +621,7 @@ class CheckIDRequest(OpenIDRequest): self.return_to = message.getArg(OPENID_NS, 'return_to') if message.isOpenID1() and not self.return_to: fmt = "Missing required field 'return_to' from %r" - raise ProtocolError(message, text=fmt % (message,)) + raise ProtocolError(message, text=fmt % (message, )) self.identity = message.getArg(OPENID_NS, 'identity') self.claimed_id = message.getArg(OPENID_NS, 'claimed_id') @@ -644,13 +651,14 @@ class CheckIDRequest(OpenIDRequest): # Using 'or' here is slightly different than sending a default # argument to getArg, as it will treat no value and an empty # string as equivalent. - self.trust_root = (message.getArg(OPENID_NS, trust_root_param) - or self.return_to) + self.trust_root = (message.getArg(OPENID_NS, trust_root_param) or + self.return_to) if not message.isOpenID1(): if self.return_to is self.trust_root is None: - raise ProtocolError(message, "openid.realm required when " + - "openid.return_to absent") + raise ProtocolError( + message, + "openid.realm required when " + "openid.return_to absent") self.assoc_handle = message.getArg(OPENID_NS, 'assoc_handle') @@ -786,7 +794,7 @@ class CheckIDRequest(OpenIDRequest): # You should pay attention to it now. raise RuntimeError("%s should be constructed with op_endpoint " "to respond to OpenID 2.0 messages." % - (self,)) + (self, )) server_url = self.op_endpoint if allow: @@ -807,7 +815,7 @@ class CheckIDRequest(OpenIDRequest): if claimed_id and self.message.isOpenID1(): namespace = self.message.getOpenIDNamespace() raise VersionError("claimed_id is new in OpenID 2.0 and not " - "available for %s" % (namespace,)) + "available for %s" % (namespace, )) if allow: if self.identity == IDENTIFIER_SELECT: @@ -824,7 +832,7 @@ class CheckIDRequest(OpenIDRequest): normalized_answer_identity = urinorm(identity) if (normalized_request_identity != - normalized_answer_identity): + normalized_answer_identity): raise ValueError( "Request was for identity %r, cannot reply " "with identity %r" % (self.identity, identity)) @@ -838,30 +846,29 @@ class CheckIDRequest(OpenIDRequest): if identity: raise ValueError( "This request specified no identity and you " - "supplied %r" % (identity,)) + "supplied %r" % (identity, )) response_identity = None if self.message.isOpenID1() and response_identity is None: raise ValueError( "Request was an OpenID 1 request, so response must " - "include an identifier." - ) + "include an identifier.") response.fields.updateArgs(OPENID_NS, { 'mode': mode, 'return_to': self.return_to, 'response_nonce': mkNonce(), - }) + }) if server_url: response.fields.setArg(OPENID_NS, 'op_endpoint', server_url) if response_identity is not None: - response.fields.setArg( - OPENID_NS, 'identity', response_identity) + response.fields.setArg(OPENID_NS, 'identity', + response_identity) if self.message.isOpenID2(): - response.fields.setArg( - OPENID_NS, 'claimed_id', response_claimed_id) + response.fields.setArg(OPENID_NS, 'claimed_id', + response_claimed_id) else: response.fields.setArg(OPENID_NS, 'mode', mode) if self.immediate: @@ -870,9 +877,13 @@ class CheckIDRequest(OpenIDRequest): "in OpenID 1.x immediate mode.") # Make a new request just like me, but with immediate=False. setup_request = self.__class__( - self.identity, self.return_to, self.trust_root, - immediate=False, assoc_handle=self.assoc_handle, - op_endpoint=self.op_endpoint, claimed_id=self.claimed_id) + self.identity, + self.return_to, + self.trust_root, + immediate=False, + assoc_handle=self.assoc_handle, + op_endpoint=self.op_endpoint, + claimed_id=self.claimed_id) # XXX: This API is weird. setup_request.message = self.message @@ -899,10 +910,12 @@ class CheckIDRequest(OpenIDRequest): # in both the client and server code, so Requests are Encodable too. # That's right, code imported from alternate realities all for the # love of you, id_res/user_setup_url. - q = {'mode': self.mode, - 'identity': self.identity, - 'claimed_id': self.claimed_id, - 'return_to': self.return_to} + q = { + 'mode': self.mode, + 'identity': self.identity, + 'claimed_id': self.claimed_id, + 'return_to': self.return_to + } if self.trust_root: if self.message.isOpenID1(): q['trust_root'] = self.trust_root @@ -942,11 +955,9 @@ class CheckIDRequest(OpenIDRequest): return response.toURL(self.return_to) def __repr__(self): - return '<%s id:%r im:%s tr:%r ah:%r>' % (self.__class__.__name__, - self.identity, - self.immediate, - self.trust_root, - self.assoc_handle) + return '<%s id:%r im:%s tr:%r ah:%r>' % ( + self.__class__.__name__, self.identity, self.immediate, + self.trust_root, self.assoc_handle) class OpenIDResponse(object): @@ -980,10 +991,8 @@ class OpenIDResponse(object): self.fields = Message(request.namespace) def __str__(self): - return "%s for %s: %s" % ( - self.__class__.__name__, - self.request.__class__.__name__, - self.fields) + return "%s for %s: %s" % (self.__class__.__name__, + self.request.__class__.__name__, self.fields) def toFormMarkup(self, form_tag_attrs=None): """Returns the form markup for this response. @@ -997,8 +1006,8 @@ class OpenIDResponse(object): @since: 2.1.0 """ - return self.fields.toFormMarkup(self.request.return_to, - form_tag_attrs=form_tag_attrs) + return self.fields.toFormMarkup( + self.request.return_to, form_tag_attrs=form_tag_attrs) def toHTML(self, form_tag_attrs=None): """Returns an HTML document that auto-submits the form markup @@ -1162,16 +1171,14 @@ class Signatory(object): assoc = self.getAssociation(assoc_handle, dumb=True) if not assoc: logging.error("failed to get assoc with handle %r to verify " - "message %r" - % (assoc_handle, message)) + "message %r" % (assoc_handle, message)) return False try: valid = assoc.checkMessageSignature(message) except ValueError as ex: - logging.exception("Error in verifying %s with %s: %s" % (message, - assoc, - ex)) + logging.exception("Error in verifying %s with %s: %s" % + (message, assoc, ex)) return False return valid @@ -1196,20 +1203,20 @@ class Signatory(object): # is expired, we still need to know some properties of the # association so that we may preserve those properties when # creating the fallback association. - assoc = self.getAssociation(assoc_handle, dumb=False, - checkExpiration=False) + assoc = self.getAssociation( + assoc_handle, dumb=False, checkExpiration=False) if not assoc or assoc.expiresIn <= 0: # fall back to dumb mode - signed_response.fields.setArg( - OPENID_NS, 'invalidate_handle', assoc_handle) + signed_response.fields.setArg(OPENID_NS, 'invalidate_handle', + assoc_handle) assoc_type = assoc and assoc.assoc_type or 'HMAC-SHA1' if assoc and assoc.expiresIn <= 0: # now do the clean-up that the disabled checkExpiration # code didn't get to do. self.invalidate(assoc_handle, dumb=False) - assoc = self.createAssociation(dumb=True, - assoc_type=assoc_type) + assoc = self.createAssociation( + dumb=True, assoc_type=assoc_type) else: # dumb mode. assoc = self.createAssociation(dumb=True) @@ -1237,8 +1244,8 @@ class Signatory(object): uniq = oidutil.toBase64(cryptutil.getBytes(4)) handle = '{%s}{%x}{%s}' % (assoc_type, int(time.time()), uniq) - assoc = Association.fromExpiresIn( - self.SECRET_LIFETIME, handle, secret, assoc_type) + assoc = Association.fromExpiresIn(self.SECRET_LIFETIME, handle, secret, + assoc_type) if dumb: key = self._dumb_key @@ -1275,8 +1282,8 @@ class Signatory(object): assoc = self.store.getAssociation(key, assoc_handle) if assoc is not None and assoc.expiresIn <= 0: logging.info("requested %sdumb key %r is expired (by %s seconds)" % - ((not dumb) and 'not-' or '', - assoc_handle, assoc.expiresIn)) + ((not dumb) and 'not-' or '', assoc_handle, + assoc.expiresIn)) if checkExpiration: self.store.removeAssociation(key, assoc_handle) assoc = None @@ -1321,8 +1328,8 @@ class Encoder(object): wr.code = HTTP_ERROR elif encode_as == ENCODE_URL: location = response.encodeToURL() - wr = self.responseFactory(code=HTTP_REDIRECT, - headers={'location': location}) + wr = self.responseFactory( + code=HTTP_REDIRECT, headers={'location': location}) elif encode_as == ENCODE_HTML_FORM: wr = self.responseFactory(code=HTTP_OK, body=response.toHTML()) else: @@ -1358,9 +1365,8 @@ class SigningEncoder(Encoder): # an adapter to make the interfaces quite match. if (not isinstance(response, Exception)) and response.needsSigning(): if not self.signatory: - raise ValueError( - "Must have a store to sign this request: %s" % - (response,), response) + raise ValueError("Must have a store to sign this request: %s" % + (response, ), response) if response.fields.hasKey(OPENID_NS, 'sig'): raise AlreadySigned(response) response = self.signatory.sign(response) @@ -1376,7 +1382,7 @@ class Decoder(object): 'checkid_immediate': CheckIDRequest.fromMessage, 'check_authentication': CheckAuthRequest.fromMessage, 'associate': AssociateRequest.fromMessage, - } + } def __init__(self, server): """Construct a Decoder. @@ -1420,12 +1426,11 @@ class Decoder(object): mode = message.getArg(OPENID_NS, 'mode') if not mode: fmt = "No mode value in message %s" - raise ProtocolError(message, text=fmt % (message,)) + raise ProtocolError(message, text=fmt % (message, )) handler = self._handlers.get(mode, self.defaultDecoder) return handler(message, self.server.op_endpoint) - def defaultDecoder(self, message, server): """Called to decode queries when no handler for that mode is found. @@ -1434,7 +1439,7 @@ class Decoder(object): """ mode = message.getArg(OPENID_NS, 'mode') fmt = "Unrecognized OpenID mode %r" - raise ProtocolError(message, text=fmt % (mode,)) + raise ProtocolError(message, text=fmt % (mode, )) class Server(object): @@ -1486,13 +1491,12 @@ class Server(object): @type negotiator: L{openid.association.SessionNegotiator} """ - def __init__( - self, - store, - op_endpoint=None, - signatoryClass=Signatory, - encoderClass=SigningEncoder, - decoderClass=Decoder): + def __init__(self, + store, + op_endpoint=None, + signatoryClass=Signatory, + encoderClass=SigningEncoder, + decoderClass=Decoder): """A new L{Server}. @param store: The back-end where my associations are stored. @@ -1514,10 +1518,11 @@ class Server(object): self.negotiator = default_negotiator.copy() if not op_endpoint: - warnings.warn("%s.%s constructor requires op_endpoint parameter " - "for OpenID 2.0 servers" % - (self.__class__.__module__, self.__class__.__name__), - stacklevel=2) + warnings.warn( + "%s.%s constructor requires op_endpoint parameter " + "for OpenID 2.0 servers" % + (self.__class__.__module__, self.__class__.__name__), + stacklevel=2) self.op_endpoint = op_endpoint def handleRequest(self, request): @@ -1557,18 +1562,16 @@ class Server(object): assoc_type = request.assoc_type session_type = request.session.session_type if self.negotiator.isAllowed(assoc_type, session_type): - assoc = self.signatory.createAssociation(dumb=False, - assoc_type=assoc_type) + assoc = self.signatory.createAssociation( + dumb=False, assoc_type=assoc_type) return request.answer(assoc) else: message = ('Association type %r is not supported with ' 'session type %r' % (assoc_type, session_type)) (preferred_assoc_type, preferred_session_type) = \ self.negotiator.getAllowedType() - return request.answerUnsupported( - message, - preferred_assoc_type, - preferred_session_type) + return request.answerUnsupported(message, preferred_assoc_type, + preferred_session_type) def decodeRequest(self, query): """Transform query parameters into an L{OpenIDRequest}. @@ -1748,11 +1751,9 @@ class EncodingError(Exception): def __str__(self): if self.explanation: - s = '%s: %s' % (self.__class__.__name__, - self.explanation) + s = '%s: %s' % (self.__class__.__name__, self.explanation) else: - s = '%s for Response %s' % ( - self.__class__.__name__, self.response) + s = '%s for Response %s' % (self.__class__.__name__, self.response) return s @@ -1775,6 +1776,7 @@ class UntrustedReturnURL(ProtocolError): class MalformedReturnURL(ProtocolError): """The return_to URL doesn't look like a valid URL.""" + def __init__(self, openid_message, return_to): self.return_to = return_to ProtocolError.__init__(self, openid_message) diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index e8719e93b9e4bae66e23872d53a45d5cb06ff44d..84f08d68c300b060fd87aaf3d22c40575bafdb65 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -15,7 +15,7 @@ __all__ = [ 'extractReturnToURLs', 'returnToMatches', 'verifyReturnTo', - ] +] from openid import urinorm from openid.yadis import services @@ -27,37 +27,33 @@ import logging ############################################ _protocols = ['http', 'https'] _top_level_domains = [ - 'ac', 'ad', 'ae', 'aero', 'af', 'ag', 'ai', 'al', 'am', 'an', - 'ao', 'aq', 'ar', 'arpa', 'as', 'asia', 'at', 'au', 'aw', - 'ax', 'az', 'ba', 'bb', 'bd', 'be', 'bf', 'bg', 'bh', 'bi', - 'biz', 'bj', 'bm', 'bn', 'bo', 'br', 'bs', 'bt', 'bv', 'bw', - 'by', 'bz', 'ca', 'cat', 'cc', 'cd', 'cf', 'cg', 'ch', 'ci', - 'ck', 'cl', 'cm', 'cn', 'co', 'com', 'coop', 'cr', 'cu', 'cv', - 'cx', 'cy', 'cz', 'de', 'dj', 'dk', 'dm', 'do', 'dz', 'ec', - 'edu', 'ee', 'eg', 'er', 'es', 'et', 'eu', 'fi', 'fj', 'fk', - 'fm', 'fo', 'fr', 'ga', 'gb', 'gd', 'ge', 'gf', 'gg', 'gh', - 'gi', 'gl', 'gm', 'gn', 'gov', 'gp', 'gq', 'gr', 'gs', 'gt', - 'gu', 'gw', 'gy', 'hk', 'hm', 'hn', 'hr', 'ht', 'hu', 'id', - 'ie', 'il', 'im', 'in', 'info', 'int', 'io', 'iq', 'ir', 'is', - 'it', 'je', 'jm', 'jo', 'jobs', 'jp', 'ke', 'kg', 'kh', 'ki', - 'km', 'kn', 'kp', 'kr', 'kw', 'ky', 'kz', 'la', 'lb', 'lc', - 'li', 'lk', 'lr', 'ls', 'lt', 'lu', 'lv', 'ly', 'ma', 'mc', - 'md', 'me', 'mg', 'mh', 'mil', 'mk', 'ml', 'mm', 'mn', 'mo', - 'mobi', 'mp', 'mq', 'mr', 'ms', 'mt', 'mu', 'museum', 'mv', - 'mw', 'mx', 'my', 'mz', 'na', 'name', 'nc', 'ne', 'net', 'nf', - 'ng', 'ni', 'nl', 'no', 'np', 'nr', 'nu', 'nz', 'om', 'org', - 'pa', 'pe', 'pf', 'pg', 'ph', 'pk', 'pl', 'pm', 'pn', 'pr', - 'pro', 'ps', 'pt', 'pw', 'py', 'qa', 're', 'ro', 'rs', 'ru', - 'rw', 'sa', 'sb', 'sc', 'sd', 'se', 'sg', 'sh', 'si', 'sj', - 'sk', 'sl', 'sm', 'sn', 'so', 'sr', 'st', 'su', 'sv', 'sy', - 'sz', 'tc', 'td', 'tel', 'tf', 'tg', 'th', 'tj', 'tk', 'tl', - 'tm', 'tn', 'to', 'tp', 'tr', 'travel', 'tt', 'tv', 'tw', - 'tz', 'ua', 'ug', 'uk', 'us', 'uy', 'uz', 'va', 'vc', 've', - 'vg', 'vi', 'vn', 'vu', 'wf', 'ws', 'xn--0zwm56d', - 'xn--11b5bs3a9aj6g', 'xn--80akhbyknj4f', 'xn--9t4b11yi5a', - 'xn--deba0ad', 'xn--g6w251d', 'xn--hgbk6aj7f53bba', - 'xn--hlcj6aya9esc7a', 'xn--jxalpdlp', 'xn--kgbechtv', - 'xn--zckzah', 'ye', 'yt', 'yu', 'za', 'zm', 'zw'] + 'ac', 'ad', 'ae', 'aero', 'af', 'ag', 'ai', 'al', 'am', 'an', 'ao', 'aq', + 'ar', 'arpa', 'as', 'asia', 'at', 'au', 'aw', 'ax', 'az', 'ba', 'bb', 'bd', + 'be', 'bf', 'bg', 'bh', 'bi', 'biz', 'bj', 'bm', 'bn', 'bo', 'br', 'bs', + 'bt', 'bv', 'bw', 'by', 'bz', 'ca', 'cat', 'cc', 'cd', 'cf', 'cg', 'ch', + 'ci', 'ck', 'cl', 'cm', 'cn', 'co', 'com', 'coop', 'cr', 'cu', 'cv', 'cx', + 'cy', 'cz', 'de', 'dj', 'dk', 'dm', 'do', 'dz', 'ec', 'edu', 'ee', 'eg', + 'er', 'es', 'et', 'eu', 'fi', 'fj', 'fk', 'fm', 'fo', 'fr', 'ga', 'gb', + 'gd', 'ge', 'gf', 'gg', 'gh', 'gi', 'gl', 'gm', 'gn', 'gov', 'gp', 'gq', + 'gr', 'gs', 'gt', 'gu', 'gw', 'gy', 'hk', 'hm', 'hn', 'hr', 'ht', 'hu', + 'id', 'ie', 'il', 'im', 'in', 'info', 'int', 'io', 'iq', 'ir', 'is', 'it', + 'je', 'jm', 'jo', 'jobs', 'jp', 'ke', 'kg', 'kh', 'ki', 'km', 'kn', 'kp', + 'kr', 'kw', 'ky', 'kz', 'la', 'lb', 'lc', 'li', 'lk', 'lr', 'ls', 'lt', + 'lu', 'lv', 'ly', 'ma', 'mc', 'md', 'me', 'mg', 'mh', 'mil', 'mk', 'ml', + 'mm', 'mn', 'mo', 'mobi', 'mp', 'mq', 'mr', 'ms', 'mt', 'mu', 'museum', + 'mv', 'mw', 'mx', 'my', 'mz', 'na', 'name', 'nc', 'ne', 'net', 'nf', 'ng', + 'ni', 'nl', 'no', 'np', 'nr', 'nu', 'nz', 'om', 'org', 'pa', 'pe', 'pf', + 'pg', 'ph', 'pk', 'pl', 'pm', 'pn', 'pr', 'pro', 'ps', 'pt', 'pw', 'py', + 'qa', 're', 'ro', 'rs', 'ru', 'rw', 'sa', 'sb', 'sc', 'sd', 'se', 'sg', + 'sh', 'si', 'sj', 'sk', 'sl', 'sm', 'sn', 'so', 'sr', 'st', 'su', 'sv', + 'sy', 'sz', 'tc', 'td', 'tel', 'tf', 'tg', 'th', 'tj', 'tk', 'tl', 'tm', + 'tn', 'to', 'tp', 'tr', 'travel', 'tt', 'tv', 'tw', 'tz', 'ua', 'ug', 'uk', + 'us', 'uy', 'uz', 'va', 'vc', 've', 'vg', 'vi', 'vn', 'vu', 'wf', 'ws', + 'xn--0zwm56d', 'xn--11b5bs3a9aj6g', 'xn--80akhbyknj4f', 'xn--9t4b11yi5a', + 'xn--deba0ad', 'xn--g6w251d', 'xn--hgbk6aj7f53bba', 'xn--hlcj6aya9esc7a', + 'xn--jxalpdlp', 'xn--kgbechtv', 'xn--zckzah', 'ye', 'yt', 'yu', 'za', 'zm', + 'zw' +] # Build from RFC3986, section 3.2.2. Used to reject hosts with invalid # characters. @@ -70,15 +66,15 @@ class RealmVerificationRedirected(Exception): @since: 2.1.0 """ + def __init__(self, relying_party_url, rp_url_after_redirects): self.relying_party_url = relying_party_url self.rp_url_after_redirects = rp_url_after_redirects def __str__(self): return ("Attempting to verify %r resulted in " - "redirect to %r" % - (self.relying_party_url, - self.rp_url_after_redirects)) + "redirect to %r" % (self.relying_party_url, + self.rp_url_after_redirects)) def _parseURL(url): @@ -222,8 +218,7 @@ class TrustRoot(object): if not self.wildcard: if host != self.host: return False - elif ((not host.endswith(self.host)) and - ('.' + host) != self.host): + elif ((not host.endswith(self.host)) and ('.' + host) != self.host): return False if path != self.path: @@ -243,8 +238,7 @@ class TrustRoot(object): else: allowed = '?/' - return (self.path[-1] in allowed or - path[path_len] in allowed) + return (self.path[-1] in allowed or path[path_len] in allowed) return True @@ -352,12 +346,14 @@ class TrustRoot(object): def __str__(self): return repr(self) + # The URI for relying party discovery, used in realm verification. # # XXX: This should probably live somewhere else (like in # openid.consumer or openid.yadis somewhere) RP_RETURN_TO_URL_TYPE = 'http://specs.openid.net/auth/2.0/return_to' + def _extractReturnURL(endpoint): """If the endpoint is a relying party OpenID return_to endpoint, return the endpoint URL. Otherwise, return None. @@ -380,6 +376,7 @@ def _extractReturnURL(endpoint): else: return None + def returnToMatches(allowed_return_to_urls, return_to): """Is the return_to URL under one of the supplied allowed return_to URLs? @@ -394,20 +391,20 @@ def returnToMatches(allowed_return_to_urls, return_to): # a wildcard. return_realm = TrustRoot.parse(allowed_return_to) - if (# Parses as a trust root - return_realm is not None and + if ( # Parses as a trust root + return_realm is not None and - # Does not have a wildcard - not return_realm.wildcard and + # Does not have a wildcard + not return_realm.wildcard and - # Matches the return_to that we passed in with it - return_realm.validateURL(return_to) - ): + # Matches the return_to that we passed in with it + return_realm.validateURL(return_to)): return True # No URL in the list matched return False + def getAllowedReturnURLs(relying_party_url): """Given a relying party discovery URL return a list of return_to URLs. @@ -418,11 +415,12 @@ def getAllowedReturnURLs(relying_party_url): if rp_url_after_redirects != relying_party_url: # Verification caused a redirect - raise RealmVerificationRedirected( - relying_party_url, rp_url_after_redirects) + raise RealmVerificationRedirected(relying_party_url, + rp_url_after_redirects) return return_to_urls + # _vrfy parameter is there to make testing easier def verifyReturnTo(realm_str, return_to, _vrfy=getAllowedReturnURLs): """Verify that a return_to URL is valid for the given realm. @@ -452,5 +450,5 @@ def verifyReturnTo(realm_str, return_to, _vrfy=getAllowedReturnURLs): return True else: logging.error("Failed to validate return_to %r for realm %r, was not " - "in %s" % (return_to, realm_str, allowable_urls)) + "in %s" % (return_to, realm_str, allowable_urls)) return False diff --git a/openid/store/filestore.py b/openid/store/filestore.py index 76ea2a71a5267fc7378c037c5028251f65081539..99ee48558518c7d49ba818ee5893d206967e30fc 100644 --- a/openid/store/filestore.py +++ b/openid/store/filestore.py @@ -60,6 +60,7 @@ def _removeIfPresent(filename): # File was present return 1 + def _ensureDir(dir_name): """Create dir_name as a directory if it does not exist. If it exists, make sure that it is, in fact, a directory. @@ -74,6 +75,7 @@ def _ensureDir(dir_name): if why.errno != EEXIST or not os.path.isdir(dir_name): raise + class FileOpenIDStore(OpenIDStore): """ This is a filesystem-based store for OpenID associations and @@ -113,7 +115,7 @@ class FileOpenIDStore(OpenIDStore): # directory self.temp_dir = os.path.join(directory, 'temp') - self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds + self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds self._setup() @@ -313,8 +315,8 @@ class FileOpenIDStore(OpenIDStore): url_hash = _safe64(server_url) salt_hash = _safe64(salt) - filename = '%08x-%s-%s-%s-%s' % (timestamp, proto, domain, - url_hash, salt_hash) + filename = '%08x-%s-%s-%s-%s' % (timestamp, proto, domain, url_hash, + salt_hash) filename = os.path.join(self.nonce_dir, filename) try: @@ -331,7 +333,10 @@ class FileOpenIDStore(OpenIDStore): def _allAssocs(self): all_associations = [] - association_filenames = [os.path.join(self.association_dir, filename) for filename in os.listdir(self.association_dir)] + association_filenames = [ + os.path.join(self.association_dir, filename) + for filename in os.listdir(self.association_dir) + ] for association_filename in association_filenames: try: association_file = open(association_filename, 'rb') diff --git a/openid/store/interface.py b/openid/store/interface.py index bb90972f7f3e25d84e41b4aaec53073b6445785d..63776572d00834f96150b83c28d99c51347f0035 100644 --- a/openid/store/interface.py +++ b/openid/store/interface.py @@ -3,6 +3,7 @@ This module contains the definition of the C{L{OpenIDStore}} interface. """ + class OpenIDStore(object): """ This is the interface for the store objects the OpenID library diff --git a/openid/store/memstore.py b/openid/store/memstore.py index 365bd7930a0bbe538f5d62ad0c95878c2562b2bf..21e3e69c09170e7263fdd795280dd7b24a033452 100644 --- a/openid/store/memstore.py +++ b/openid/store/memstore.py @@ -54,6 +54,7 @@ class MemoryStore(object): Use for single long-running processes. No persistence supplied. """ + def __init__(self): self.server_assocs = {} self.nonces = {} diff --git a/openid/store/nonce.py b/openid/store/nonce.py index e9337a8a41303551f871ac49e5f1f9c75041c5e5..06c12149a4b12c4aa1ae32073838a0abe0e9c357 100644 --- a/openid/store/nonce.py +++ b/openid/store/nonce.py @@ -2,7 +2,7 @@ __all__ = [ 'split', 'mkNonce', 'checkTimestamp', - ] +] from openid import cryptutil from time import strptime, strftime, gmtime, time @@ -19,6 +19,7 @@ SKEW = 60 * 60 * 5 time_fmt = '%Y-%m-%dT%H:%M:%SZ' time_str_len = len('0000-00-00T00:00:00Z') + def split(nonce_string): """Extract a timestamp from the given nonce string @@ -34,12 +35,13 @@ def split(nonce_string): timestamp_str = nonce_string[:time_str_len] try: timestamp = timegm(strptime(timestamp_str, time_fmt)) - except AssertionError: # Python 2.2 + except AssertionError: # Python 2.2 timestamp = -1 if timestamp < 0: raise ValueError('time out of range') return timestamp, nonce_string[time_str_len:] + def checkTimestamp(nonce_string, allowed_skew=SKEW, now=None): """Is the timestamp that is part of the specified nonce string within the allowed clock-skew of the current time? @@ -76,6 +78,7 @@ def checkTimestamp(nonce_string, allowed_skew=SKEW, now=None): # the past return past <= stamp <= future + def mkNonce(when=None): """Generate a nonce with the current timestamp diff --git a/openid/store/sqlstore.py b/openid/store/sqlstore.py index 690ed19bc61aba05f76b8bb1a98ba34b556c98f3..2005ee57b10dac9831929cedb121218bf36fba80 100644 --- a/openid/store/sqlstore.py +++ b/openid/store/sqlstore.py @@ -106,14 +106,14 @@ class SQLStore(OpenIDStore): self._table_names = { 'associations': associations_table or self.associations_table, 'nonces': nonces_table or self.nonces_table, - } - self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds + } + self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds # DB API extension: search for "Connection Attributes .Error, # .ProgrammingError, etc." in # http://www.python.org/dev/peps/pep-0249/ if (hasattr(self.conn, 'IntegrityError') and - hasattr(self.conn, 'OperationalError')): + hasattr(self.conn, 'OperationalError')): self.exceptions = self.conn if not (hasattr(self.exceptions, 'IntegrityError') and @@ -143,6 +143,7 @@ class SQLStore(OpenIDStore): def _execSQL(self, sql_name, *args): sql = self._getSQL(sql_name) + # Kludge because we have reports of postgresql not quoting # arguments if they are passed in as unicode instead of str. # Currently the strings in our tables just have ascii in them, @@ -152,6 +153,7 @@ class SQLStore(OpenIDStore): return str(arg) else: return arg + str_args = list(map(unicode_to_str, args)) self.cur.execute(sql, str_args) @@ -161,12 +163,14 @@ class SQLStore(OpenIDStore): # as an attribute of this object and executes it. if attr[:3] == 'db_': sql_name = attr[3:] + '_sql' + def func(*args): return self._execSQL(sql_name, *args) + setattr(self, attr, func) return func else: - raise AttributeError('Attribute %r not found' % (attr,)) + raise AttributeError('Attribute %r not found' % (attr, )) def _callInTransaction(self, func, *args, **kwargs): """Execute the given function inside of a transaction, with an @@ -207,13 +211,9 @@ class SQLStore(OpenIDStore): Association -> NoneType """ a = association - self.db_set_assoc( - server_url, - a.handle, - self.blobEncode(a.secret), - a.issued, - a.lifetime, - a.assoc_type) + self.db_set_assoc(server_url, a.handle, + self.blobEncode(a.secret), a.issued, a.lifetime, + a.assoc_type) storeAssociation = _inTxn(txn_storeAssociation) @@ -257,7 +257,7 @@ class SQLStore(OpenIDStore): (str, str) -> bool """ self.db_remove_assoc(server_url, handle) - return self.cur.rowcount > 0 # -1 is undefined + return self.cur.rowcount > 0 # -1 is undefined removeAssociation = _inTxn(txn_removeAssociation) @@ -362,6 +362,7 @@ class SQLiteStore(SQLStore): else: raise + class MySQLStore(SQLStore): """ This is a MySQL-based specialization of C{L{SQLStore}}. @@ -457,7 +458,8 @@ class PostgreSQLStore(SQLStore): ); """ - def db_set_assoc(self, server_url, handle, secret, issued, lifetime, assoc_type): + def db_set_assoc(self, server_url, handle, secret, issued, lifetime, + assoc_type): """ Set an association. This is implemented as a method because REPLACE INTO is not supported by PostgreSQL (and is not diff --git a/openid/test/__init__.py b/openid/test/__init__.py index cb3c3b2169a593150cf979fa08ddcbda3e643047..ca0d036c53ac86659d0caf84c9f210f70a6e9901 100644 --- a/openid/test/__init__.py +++ b/openid/test/__init__.py @@ -35,7 +35,7 @@ def specialCaseTests(): try: test_mod = __import__(module_name, {}, {}, [None]) except ImportError: - print(('Failed to import test %r' % (module_name,))) + print(('Failed to import test %r' % (module_name, ))) else: suite.addTest(unittest.FunctionTestCase(test_mod.test)) @@ -72,7 +72,7 @@ def pyUnitTests(): test_modules = [ __import__('openid.test.test_{}'.format(name), {}, {}, ['unused']) for name in test_module_names - ] + ] try: from openid.test import test_examples @@ -103,7 +103,7 @@ def pyUnitTests(): 'test_urinorm', 'test_yadis_discover', 'trustroot', - ] + ] loader = unittest.TestLoader() suite = unittest.TestSuite() @@ -118,7 +118,7 @@ def pyUnitTests(): except AttributeError: # because the AttributeError doesn't actually say which # object it was. - print(("Error loading tests from %s:" % (name,))) + print(("Error loading tests from %s:" % (name, ))) raise return suite @@ -159,7 +159,7 @@ def djangoExampleTests(): import djopenid.server.models import djopenid.consumer.models - print ("Testing Django examples:") + print("Testing Django examples:") runner = django.test.simple.DjangoTestSuiteRunner() return runner.run_tests(['server', 'consumer']) @@ -167,9 +167,8 @@ def djangoExampleTests(): # These tests do get put into a test suite, so we could run them with the # other tests, but django also establishes a test database for them, so we # let it do that thing instead. - return django.test.simple.run_tests([djopenid.server.models, - djopenid.consumer.models]) - + return django.test.simple.run_tests( + [djopenid.server.models, djopenid.consumer.models]) def test_suite(): diff --git a/openid/test/cryptutil.py b/openid/test/cryptutil.py index f02cb3c86d1c7a5e32bd08cf98474629c6d1b964..ded655fd1b8eec95e349a46738a79e403512b53c 100644 --- a/openid/test/cryptutil.py +++ b/openid/test/cryptutil.py @@ -18,8 +18,8 @@ def test_cryptrand(): assert len(t) == 32 assert s != t - a = cryptutil.randrange(2 ** 128) - b = cryptutil.randrange(2 ** 128) + a = cryptutil.randrange(2**128) + b = cryptutil.randrange(2**128) assert type(a) is int assert type(b) is int assert b != a @@ -39,10 +39,10 @@ def test_reversed(): ('abcdefg', 'gfedcba'), ([], []), ([1], [1]), - ([1,2], [2,1]), - ([1,2,3], [3,2,1]), + ([1, 2], [2, 1]), + ([1, 2, 3], [3, 2, 1]), (list(range(1000)), list(range(999, -1, -1))), - ] + ] for case, expected in cases: expected = list(expected) @@ -64,16 +64,9 @@ def test_binaryLongConvert(): n_prime = cryptutil.binaryToLong(s) assert n == n_prime, (n, n_prime) - cases = [ - (b'\x00', 0), - (b'\x01', 1), - (b'\x7F', 127), - (b'\x00\xFF', 255), - (b'\x00\x80', 128), - (b'\x00\x81', 129), - (b'\x00\x80\x00', 32768), - (b'OpenID is cool', 1611215304203901150134421257416556) - ] + cases = [(b'\x00', 0), (b'\x01', 1), (b'\x7F', 127), (b'\x00\xFF', 255), + (b'\x00\x80', 128), (b'\x00\x81', 129), (b'\x00\x80\x00', 32768), + (b'OpenID is cool', 1611215304203901150134421257416556)] for s, n in cases: n_prime = cryptutil.binaryToLong(s) @@ -111,5 +104,6 @@ def test(): test_longToBase64() test_base64ToLong() + if __name__ == '__main__': test() diff --git a/openid/test/dh.py b/openid/test/dh.py index 1a851e87ec70fb08e7bfb1f39de1226bf4598fd5..0b103e0d53c662673e0fd97324323ddfcd59c8a5 100644 --- a/openid/test/dh.py +++ b/openid/test/dh.py @@ -15,7 +15,7 @@ def test_strxor(): (b'\x01', b'\x02', b'\x03'), (b'\xf0', b'\x0f', b'\xff'), (b'\xff', b'\x0f', b'\xf0'), - ] + ] for aa, bb, expected in cases: actual = strxor(aa, bb) @@ -25,9 +25,8 @@ def test_strxor(): ('', 'a'), ('foo', 'ba'), (NUL * 3, NUL * 4), - (''.join(map(chr, range(256))), - ''.join(map(chr, range(128)))), - ] + (''.join(map(chr, range(256))), ''.join(map(chr, range(128)))), + ] for aa, bb in exc_cases: try: @@ -35,7 +34,7 @@ def test_strxor(): except ValueError: pass else: - assert False, 'Expected ValueError, got %r' % (unexpected,) + assert False, 'Expected ValueError, got %r' % (unexpected, ) def test1(): @@ -71,5 +70,6 @@ def test(): test_public() test_strxor() + if __name__ == '__main__': test() diff --git a/openid/test/discoverdata.py b/openid/test/discoverdata.py index 3da6d51769f53dbb886ae9666fe1aeb2d1fd582d..75bf16ddaddb10f7ddb27d8568521d478284042d 100644 --- a/openid/test/discoverdata.py +++ b/openid/test/discoverdata.py @@ -9,25 +9,25 @@ tests_dir = os.path.dirname(__file__) data_path = os.path.join(tests_dir, 'data') testlist = [ -# success, input_name, id_name, result_name - (True, "equiv", "equiv", "xrds"), - (True, "header", "header", "xrds"), - (True, "lowercase_header", "lowercase_header", "xrds"), - (True, "xrds", "xrds", "xrds"), - (True, "xrds_ctparam", "xrds_ctparam", "xrds_ctparam"), - (True, "xrds_ctcase", "xrds_ctcase", "xrds_ctcase"), - (False, "xrds_html", "xrds_html", "xrds_html"), - (True, "redir_equiv", "equiv", "xrds"), - (True, "redir_header", "header", "xrds"), - (True, "redir_xrds", "xrds", "xrds"), - (False, "redir_xrds_html", "xrds_html", "xrds_html"), - (True, "redir_redir_equiv", "equiv", "xrds"), - (False, "404_server_response", None, None), - (False, "404_with_header", None, None), - (False, "404_with_meta", None, None), - (False, "201_server_response", None, None), - (False, "500_server_response", None, None), - ] + # success, input_name, id_name, result_name + (True, "equiv", "equiv", "xrds"), + (True, "header", "header", "xrds"), + (True, "lowercase_header", "lowercase_header", "xrds"), + (True, "xrds", "xrds", "xrds"), + (True, "xrds_ctparam", "xrds_ctparam", "xrds_ctparam"), + (True, "xrds_ctcase", "xrds_ctcase", "xrds_ctcase"), + (False, "xrds_html", "xrds_html", "xrds_html"), + (True, "redir_equiv", "equiv", "xrds"), + (True, "redir_header", "header", "xrds"), + (True, "redir_xrds", "xrds", "xrds"), + (False, "redir_xrds_html", "xrds_html", "xrds_html"), + (True, "redir_redir_equiv", "equiv", "xrds"), + (False, "404_server_response", None, None), + (False, "404_with_header", None, None), + (False, "404_with_meta", None, None), + (False, "201_server_response", None, None), + (False, "500_server_response", None, None), +] def getDataName(*components): @@ -49,6 +49,7 @@ def getExampleXRDS(): with open(filename) as f: return f.read() + example_xrds = getExampleXRDS() default_test_file = getDataName('test1-discover.txt') @@ -80,7 +81,7 @@ def fillTemplate(test_name, template, base_url, example_xrds): ('<XRDS Content>', example_xrds), ('YADIS_HEADER', YADIS_HEADER_NAME), ('NAME', test_name), - ] + ] for k, v in mapping: template = template.replace(k, v) @@ -88,7 +89,8 @@ def fillTemplate(test_name, template, base_url, example_xrds): return template -def generateSample(test_name, base_url, +def generateSample(test_name, + base_url, example_xrds=example_xrds, filename=default_test_file): try: diff --git a/openid/test/kvform.py b/openid/test/kvform.py index cfe3333b498f5471efd361725f652fde07554e8c..cf1d09511e5a0c3905fe641c9dc28307dbe9145c 100644 --- a/openid/test/kvform.py +++ b/openid/test/kvform.py @@ -79,15 +79,21 @@ class KVSeqTest(KVBaseTest): self.assertEqual(seq, clean_seq) self.checkWarnings(self.expected_warnings) + kvdict_cases = [ # (kvform, parsed dictionary, expected warnings) ('', {}, 0), - ('college:harvey mudd\n', {'college':'harvey mudd'}, 0), - ('city:claremont\nstate:CA\n', - {'city':'claremont', 'state':'CA'}, 0), - ('is_valid:true\ninvalidate_handle:{HMAC-SHA1:2398410938412093}\n', - {'is_valid':'true', - 'invalidate_handle':'{HMAC-SHA1:2398410938412093}'}, 0), + ('college:harvey mudd\n', { + 'college': 'harvey mudd' + }, 0), + ('city:claremont\nstate:CA\n', { + 'city': 'claremont', + 'state': 'CA' + }, 0), + ('is_valid:true\ninvalidate_handle:{HMAC-SHA1:2398410938412093}\n', { + 'is_valid': 'true', + 'invalidate_handle': '{HMAC-SHA1:2398410938412093}' + }, 0), # Warnings from lines with no colon: ('x\n', {}, 1), @@ -98,18 +104,33 @@ kvdict_cases = [ ('x\n\n', {}, 1), # Warning from empty key - (':\n', {'':''}, 1), - (':missing key\n', {'':'missing key'}, 1), + (':\n', { + '': '' + }, 1), + (':missing key\n', { + '': 'missing key' + }, 1), # Warnings from leading or trailing whitespace in key or value - (' street:foothill blvd\n', {'street':'foothill blvd'}, 1), - ('major: computer science\n', {'major':'computer science'}, 1), - (' dorm : east \n', {'dorm':'east'}, 2), + (' street:foothill blvd\n', { + 'street': 'foothill blvd' + }, 1), + ('major: computer science\n', { + 'major': 'computer science' + }, 1), + (' dorm : east \n', { + 'dorm': 'east' + }, 2), # Warnings from missing trailing newline - ('e^(i*pi)+1:0', {'e^(i*pi)+1':'0'}, 1), - ('east:west\nnorth:south', {'east':'west', 'north':'south'}, 1), - ] + ('e^(i*pi)+1:0', { + 'e^(i*pi)+1': '0' + }, 1), + ('east:west\nnorth:south', { + 'east': 'west', + 'north': 'south' + }, 1), +] kvseq_cases = [ ([], '', 0), @@ -120,23 +141,21 @@ kvseq_cases = [ # If it's a UTF-8 str, make sure that it's equivalent to the same # string, decoded. ([('\xce\xbbx', 'x')], '\xce\xbbx:x\n', 0), - ([('openid', 'useful'), ('a', 'b')], 'openid:useful\na:b\n', 0), # Warnings about leading whitespace ([(' openid', 'useful'), ('a', 'b')], ' openid:useful\na:b\n', 2), # Warnings about leading and trailing whitespace - ([(' openid ', ' useful '), - (' a ', ' b ')], ' openid : useful \n a : b \n', 8), + ([(' openid ', ' useful '), (' a ', ' b ')], + ' openid : useful \n a : b \n', 8), # warnings about leading and trailing whitespace, but not about # internal whitespace. - ([(' open id ', ' use ful '), - (' a ', ' b ')], ' open id : use ful \n a : b \n', 8), - + ([(' open id ', ' use ful '), (' a ', ' b ')], + ' open id : use ful \n a : b \n', 8), ([('foo', 'bar')], 'foo:bar\n', 0), - ] +] kvexc_cases = [ [('openid', 'use\nful')], @@ -145,7 +164,7 @@ kvexc_cases = [ [('open:id', 'useful')], [('foo', 'bar'), ('ba\n d', 'seed')], [('foo', 'bar'), ('bad:', 'seed')], - ] +] class KVExcTest(unittest.TestCase): @@ -154,7 +173,7 @@ class KVExcTest(unittest.TestCase): self.seq = seq def shortDescription(self): - return 'KVExcTest for %r' % (self.seq,) + return 'KVExcTest for %r' % (self.seq, ) def runTest(self): self.assertRaises(ValueError, kvform.seqToKV, self.seq) @@ -176,6 +195,7 @@ def pyUnitTests(): tests.append(unittest.defaultTestLoader.loadTestsFromTestCase(GeneralTest)) return unittest.TestSuite(tests) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/linkparse.py b/openid/test/linkparse.py index f6f8313f14b81a5122a7337a8b3bdee3d1e585c8..8d45c65fa4ed6f4e664729792e0e0c1282dc9ee9 100644 --- a/openid/test/linkparse.py +++ b/openid/test/linkparse.py @@ -108,6 +108,7 @@ def pyUnitTests(): return unittest.TestSuite(tests) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/oidutil.py b/openid/test/oidutil.py index 8b5cfd597faeb188d6099d4eb1c89bc345adfc7c..ac4b62bbca55d2081824c46f71d6ea744ad5745c 100644 --- a/openid/test/oidutil.py +++ b/openid/test/oidutil.py @@ -24,7 +24,7 @@ def test_base64(): '\x01', '\x00' * 100, ''.join(map(chr, list(range(256)))), - ] + ] for s in cases: b64 = oidutil.toBase64(s) @@ -58,7 +58,6 @@ class AppendArgsTest(unittest.TestCase): class TestUnicodeConversion(unittest.TestCase): - def test_toUnicode(self): # Unicode objects pass through self.assertTrue(isinstance(oidutil.toUnicode('fööbär'), str)) @@ -87,78 +86,48 @@ class TestSymbol(unittest.TestCase): def buildAppendTests(): simple = 'http://www.example.com/' cases = [ - ('empty list', - (simple, []), - simple), - - ('empty dict', - (simple, {}), - simple), - - ('one list', - (simple, [('a', 'b')]), - simple + '?a=b'), - - ('one dict', - (simple, {'a':'b'}), - simple + '?a=b'), - - ('two list (same)', - (simple, [('a', 'b'), ('a', 'c')]), + ('empty list', (simple, []), simple), + ('empty dict', (simple, {}), simple), + ('one list', (simple, [('a', 'b')]), simple + '?a=b'), + ('one dict', (simple, { + 'a': 'b' + }), simple + '?a=b'), + ('two list (same)', (simple, [('a', 'b'), ('a', 'c')]), simple + '?a=b&a=c'), - - ('two list', - (simple, [('a', 'b'), ('b', 'c')]), - simple + '?a=b&b=c'), - - ('two list (order)', - (simple, [('b', 'c'), ('a', 'b')]), + ('two list', (simple, [('a', 'b'), ('b', 'c')]), simple + '?a=b&b=c'), + ('two list (order)', (simple, [('b', 'c'), ('a', 'b')]), simple + '?b=c&a=b'), - - ('two dict (order)', - (simple, {'b':'c', 'a':'b'}), - simple + '?a=b&b=c'), - - ('escape', - (simple, [('=', '=')]), - simple + '?%3D=%3D'), - - ('escape (URL)', - (simple, [('this_url', simple)]), + ('two dict (order)', (simple, { + 'b': 'c', + 'a': 'b' + }), simple + '?a=b&b=c'), + ('escape', (simple, [('=', '=')]), simple + '?%3D=%3D'), + ('escape (URL)', (simple, [('this_url', simple)]), simple + '?this_url=http%3A%2F%2Fwww.example.com%2F'), - - ('use dots', - (simple, [('openid.stuff', 'bother')]), + ('use dots', (simple, [('openid.stuff', 'bother')]), simple + '?openid.stuff=bother'), - - ('args exist (empty)', - (simple + '?stuff=bother', []), + ('args exist (empty)', (simple + '?stuff=bother', []), simple + '?stuff=bother'), - - ('args exist', - (simple + '?stuff=bother', [('ack', 'ack')]), - simple + '?stuff=bother&ack=ack'), - - ('args exist', - (simple + '?stuff=bother', [('ack', 'ack')]), + ('args exist', (simple + '?stuff=bother', [('ack', 'ack')]), simple + '?stuff=bother&ack=ack'), - - ('args exist (dict)', - (simple + '?stuff=bother', {'ack': 'ack'}), + ('args exist', (simple + '?stuff=bother', [('ack', 'ack')]), simple + '?stuff=bother&ack=ack'), - - ('args exist (dict 2)', - (simple + '?stuff=bother', {'ack': 'ack', 'zebra':'lion'}), - simple + '?stuff=bother&ack=ack&zebra=lion'), - - ('three args (dict)', - (simple, {'stuff': 'bother', 'ack': 'ack', 'zebra':'lion'}), - simple + '?ack=ack&stuff=bother&zebra=lion'), - + ('args exist (dict)', (simple + '?stuff=bother', { + 'ack': 'ack' + }), simple + '?stuff=bother&ack=ack'), + ('args exist (dict 2)', (simple + '?stuff=bother', { + 'ack': 'ack', + 'zebra': 'lion' + }), simple + '?stuff=bother&ack=ack&zebra=lion'), + ('three args (dict)', (simple, { + 'stuff': 'bother', + 'ack': 'ack', + 'zebra': 'lion' + }), simple + '?ack=ack&stuff=bother&zebra=lion'), ('three args (list)', (simple, [('stuff', 'bother'), ('ack', 'ack'), ('zebra', 'lion')]), simple + '?stuff=bother&ack=ack&zebra=lion'), - ] + ] tests = [] @@ -172,7 +141,9 @@ def buildAppendTests(): def pyUnitTests(): some = buildAppendTests() some.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestSymbol)) - some.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestUnicodeConversion)) + some.addTest( + unittest.defaultTestLoader.loadTestsFromTestCase( + TestUnicodeConversion)) return some @@ -188,10 +159,12 @@ def test_appendArgs(): # specified and tested in oidutil.py These include, but are not # limited to appendArgs + def test(skipPyUnit=True): test_base64() if not skipPyUnit: test_appendArgs() + if __name__ == '__main__': test(skipPyUnit=False) diff --git a/openid/test/storetest.py b/openid/test/storetest.py index 9aadb245f5197a5ddfa1aa6910bd2ef85d51c9a4..347126198cdb9ed7e1c3d998275a02bc6d0cd5ee 100644 --- a/openid/test/storetest.py +++ b/openid/test/storetest.py @@ -14,7 +14,6 @@ from openid.association import Association from openid.cryptutil import randomString from openid.store.nonce import mkNonce, split - db_host = os.environ.get('TEST_DB_HOST', 'dbtest') allowed_handle = [] @@ -60,11 +59,12 @@ def testStore(store): def checkRetrieve(url, handle=None, expected=None): retrieved_assoc = store.getAssociation(url, handle) - assert retrieved_assoc == expected, (retrieved_assoc.__dict__, expected.__dict__) + assert retrieved_assoc == expected, (retrieved_assoc.__dict__, + expected.__dict__) if expected is not None: if retrieved_assoc is expected: - print ('Unexpected: retrieved a reference to the expected ' - 'value instead of a new object') + print('Unexpected: retrieved a reference to the expected ' + 'value instead of a new object') assert retrieved_assoc.handle == expected.handle assert retrieved_assoc.secret == expected.secret @@ -200,7 +200,7 @@ def testStore(store): # Nonces from when the universe was an hour old should not pass now. old_nonce = mkNonce(3600) checkUseNonce(old_nonce, False, url, - "Old nonce (%r) passed." % (old_nonce,)) + "Old nonce (%r) passed." % (old_nonce, )) old_nonce1 = mkNonce(now - 20000) old_nonce2 = mkNonce(now - 10000) @@ -219,7 +219,7 @@ def testStore(store): nonceModule.SKEW = 3600 cleaned = store.cleanupNonces() - assert cleaned == 2, "Cleaned %r nonces." % (cleaned,) + assert cleaned == 2, "Cleaned %r nonces." % (cleaned, ) nonceModule.SKEW = 100000 # A roundabout method of checking that the old nonces were cleaned is @@ -271,13 +271,13 @@ def test_mysql(): # Change this connect line to use the right user and password try: - conn = MySQLdb.connect(user=db_user, passwd=db_passwd, - host=db_host) + conn = MySQLdb.connect( + user=db_user, passwd=db_passwd, host=db_host) except MySQLdb.OperationalError as why: if why.args[0] == 2005: - raise unittest.SkipTest('Skipping MySQL store test. ' - 'Cannot connect to server on host %r.' - % (db_host,)) + raise unittest.SkipTest( + 'Skipping MySQL store test. ' + 'Cannot connect to server on host %r.' % (db_host, )) else: raise @@ -337,22 +337,21 @@ def test_postgresql(): # Connect once to create the database; reconnect to access the # new database. try: - conn_create = psycopg2.connect(database='template1', user=db_user, - host=db_host) + conn_create = psycopg2.connect( + database='template1', user=db_user, host=db_host) except psycopg2.OperationalError as why: - raise unittest.SkipTest('Skipping PostgreSQL store test: %s' - % why) + raise unittest.SkipTest('Skipping PostgreSQL store test: %s' % why) conn_create.autocommit = True # Create the test database. cursor = conn_create.cursor() - cursor.execute('CREATE DATABASE %s;' % (db_name,)) + cursor.execute('CREATE DATABASE %s;' % (db_name, )) conn_create.close() # Connect to the test database. - conn_test = psycopg2.connect(database=db_name, user=db_user, - host=db_host) + conn_test = psycopg2.connect( + database=db_name, user=db_user, host=db_host) # OK, we're in the right environment. Create the store # instance and create the tables. @@ -373,12 +372,12 @@ def test_postgresql(): time.sleep(1) # Remove the database now that the test is over. - conn_remove = psycopg2.connect(database='template1', user=db_user, - host=db_host) + conn_remove = psycopg2.connect( + database='template1', user=db_user, host=db_host) conn_remove.autocommit = True cursor = conn_remove.cursor() - cursor.execute('DROP DATABASE %s;' % (db_name,)) + cursor.execute('DROP DATABASE %s;' % (db_name, )) conn_remove.close() @@ -386,19 +385,21 @@ def test_memstore(): from openid.store import memstore testStore(memstore.MemoryStore()) + test_functions = [ test_filestore, test_sqlite, test_mysql, test_postgresql, test_memstore, - ] +] def pyUnitTests(): tests = list(map(unittest.FunctionTestCase, test_functions)) return unittest.TestSuite(tests) + if __name__ == '__main__': import sys suite = pyUnitTests() diff --git a/openid/test/support.py b/openid/test/support.py index bc8a4b9a09e2e0856c85ea6155de550be2d19ddf..b9cf73d4efabbb0d54331e6a70a9e93a00e67219 100644 --- a/openid/test/support.py +++ b/openid/test/support.py @@ -2,6 +2,7 @@ from openid import message from logging.handlers import BufferingHandler import logging + class TestHandler(BufferingHandler): def __init__(self, messages): BufferingHandler.__init__(self, 0) @@ -13,6 +14,7 @@ class TestHandler(BufferingHandler): def emit(self, record): self.messages.append(record.__dict__) + class OpenIDTestMixin(object): def failUnlessOpenIDValueEquals(self, msg, key, expected, ns=None): if ns is None: @@ -31,6 +33,7 @@ class OpenIDTestMixin(object): error_message = 'openid.%s unexpectedly present: %s' % (key, actual) self.assertFalse(actual is not None, error_message) + class CatchLogs(object): def setUp(self): self.messages = [] @@ -39,7 +42,8 @@ class CatchLogs(object): root_logger.setLevel(logging.DEBUG) self.handler = TestHandler(self.messages) - formatter = logging.Formatter("%(message)s [%(asctime)s - %(name)s - %(levelname)s]") + formatter = logging.Formatter( + "%(message)s [%(asctime)s - %(name)s - %(levelname)s]") self.handler.setFormatter(formatter) root_logger.addHandler(self.handler) diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index 739f096c55513612cd907ecff74d95b07afa3add..a4084ea37bff126cad08669515385a13d1a32fca 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -2,6 +2,7 @@ import unittest import os.path from openid.yadis import accept + def getTestData(): """Read the test data off of disk @@ -15,6 +16,7 @@ def getTestData(): i += 1 return lines + def chunk(lines): """Return groups of lines separated by whitespace or comments @@ -36,6 +38,7 @@ def chunk(lines): return chunks + def parseLines(chunk): """Take the given chunk of lines and turn it into a test data dictionary @@ -49,6 +52,7 @@ def parseLines(chunk): return items + def parseAvailable(available_text): """Parse an Available: line's data @@ -56,6 +60,7 @@ def parseAvailable(available_text): """ return [s.strip() for s in available_text.split(',')] + def parseExpected(expected_text): """Parse an Expected: line's data @@ -76,6 +81,7 @@ def parseExpected(expected_text): return expected + class MatchAcceptTest(unittest.TestCase): def __init__(self, descr, accept_header, available, expected): unittest.TestCase.__init__(self) @@ -92,6 +98,7 @@ class MatchAcceptTest(unittest.TestCase): actual = accept.matchTypes(accepted, self.available) self.assertEqual(self.expected, actual) + def pyUnitTests(): lines = getTestData() chunks = chunk(lines) @@ -117,11 +124,12 @@ def pyUnitTests(): print('On line', lno) raise - descr = 'MatchAcceptTest for lines %r' % (lnos,) + descr = 'MatchAcceptTest for lines %r' % (lnos, ) case = MatchAcceptTest(descr, header, available, expected) cases.append(case) return unittest.TestSuite(cases) + if __name__ == '__main__': runner = unittest.TextTestRunner() runner.run(pyUnitTests()) diff --git a/openid/test/test_association.py b/openid/test/test_association.py index 71f515c15bfdcc5a32ea3561b7499c438436ae3c..e752fdc9826bef87e8404d8c590619243e95b61b 100644 --- a/openid/test/test_association.py +++ b/openid/test/test_association.py @@ -14,8 +14,8 @@ class AssociationSerializationTest(unittest.TestCase): issued = int(time.time()) lifetime = 600 handle = 'a-QoU6tM*#!*R\'q\\w<W>X`90>tj7d{[t~Wv@(j(V9(jcx:ZeGYbT0;N]"C}bxQ$aDjf{)"z6@+W<Wb$Vm`k9j0/tZ=\\J[0Qmp35ex[H9g<nUC9UGj4.Hlq7"Q]`w:w6Q' - assoc = association.Association( - handle, 'secret', issued, lifetime, 'HMAC-SHA1') + assoc = association.Association(handle, 'secret', issued, lifetime, + 'HMAC-SHA1') s = assoc.serialize() assoc2 = association.Association.deserialize(s) self.assertEqual(assoc.handle, assoc2.handle) @@ -48,18 +48,17 @@ class DiffieHellmanSessionTest(datadriven.DataDrivenTestCase): '\xff' * 20, ' ' * 20, 'This is a secret....', - ] + ] session_factories = [ (DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA1ServerSession), (createNonstandardConsumerDH, DiffieHellmanSHA1ServerSession), (PlainTextConsumerSession, PlainTextServerSession), - ] + ] def generateCases(cls): return [(c, s, sec) - for c, s in cls.session_factories - for sec in cls.secrets] + for c, s in cls.session_factories for sec in cls.secrets] generateCases = classmethod(generateCases) @@ -89,7 +88,7 @@ class TestMakePairs(unittest.TestCase): 'identifier': '=example', 'signed': 'identifier,mode', 'sig': 'cephalopod', - }) + }) m.updateArgs(BARE_NS, {'xey': 'value'}) self.assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA1") @@ -100,14 +99,13 @@ class TestMakePairs(unittest.TestCase): expected = [ ('identifier', '=example'), ('mode', 'id_res'), - ] + ] self.assertEqual(pairs, expected) class TestMac(unittest.TestCase): def setUp(self): - self.pairs = [('key1', 'value1'), - ('key2', 'value2')] + self.pairs = [('key1', 'value1'), ('key2', 'value2')] def test_sha1(self): assoc = association.Association.fromExpiresIn( @@ -118,6 +116,7 @@ class TestMac(unittest.TestCase): self.assertEqual(sig, expected) if cryptutil.SHA256_AVAILABLE: + def test_sha256(self): assoc = association.Association.fromExpiresIn( 3600, '{sha256SA}', 'very_secret', "HMAC-SHA256") @@ -130,42 +129,45 @@ class TestMac(unittest.TestCase): class TestMessageSigning(unittest.TestCase): def setUp(self): self.message = m = Message(OPENID2_NS) - m.updateArgs(OPENID2_NS, {'mode': 'id_res', - 'identifier': '=example'}) + m.updateArgs(OPENID2_NS, {'mode': 'id_res', 'identifier': '=example'}) m.updateArgs(BARE_NS, {'xey': 'value'}) - self.args = {'openid.mode': 'id_res', - 'openid.identifier': '=example', - 'xey': 'value'} + self.args = { + 'openid.mode': 'id_res', + 'openid.identifier': '=example', + 'xey': 'value' + } def test_signSHA1(self): assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA1") signed = assoc.signMessage(self.message) self.assertTrue(signed.getArg(OPENID_NS, "sig")) - self.assertEqual(signed.getArg(OPENID_NS, "signed"), - "assoc_handle,identifier,mode,ns,signed") - self.assertEqual(signed.getArg(BARE_NS, "xey"), "value", - signed) + self.assertEqual( + signed.getArg(OPENID_NS, "signed"), + "assoc_handle,identifier,mode,ns,signed") + self.assertEqual(signed.getArg(BARE_NS, "xey"), "value", signed) if cryptutil.SHA256_AVAILABLE: + def test_signSHA256(self): assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA256") signed = assoc.signMessage(self.message) self.assertTrue(signed.getArg(OPENID_NS, "sig")) - self.assertEqual(signed.getArg(OPENID_NS, "signed"), - "assoc_handle,identifier,mode,ns,signed") - self.assertEqual(signed.getArg(BARE_NS, "xey"), "value", - signed) + self.assertEqual( + signed.getArg(OPENID_NS, "signed"), + "assoc_handle,identifier,mode,ns,signed") + self.assertEqual(signed.getArg(BARE_NS, "xey"), "value", signed) class TestCheckMessageSignature(unittest.TestCase): def test_aintGotSignedList(self): m = Message(OPENID2_NS) - m.updateArgs(OPENID2_NS, {'mode': 'id_res', - 'identifier': '=example', - 'sig': 'coyote', - }) + m.updateArgs(OPENID2_NS, { + 'mode': 'id_res', + 'identifier': '=example', + 'sig': 'coyote', + }) m.updateArgs(BARE_NS, {'xey': 'value'}) assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA1") @@ -175,6 +177,7 @@ class TestCheckMessageSignature(unittest.TestCase): def pyUnitTests(): return datadriven.loadTests(__name__) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index fcd394505c1c815877afdd9cf5e653131ddd9c41..a738449982c54a624b3d2cb1419861f0dcedf83a 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -48,7 +48,7 @@ class BaseAssocTest(CatchLogs, unittest.TestCase): message = 'Expected prefix %r, got %r' % (str_prefix, e_arg) self.assertTrue(e_arg.startswith(str_prefix), message) else: - self.fail('Expected ProtocolError, got %r' % (result,)) + self.fail('Expected ProtocolError, got %r' % (result, )) def mkExtractAssocMissingTest(keys): @@ -75,8 +75,8 @@ def mkExtractAssocMissingTest(keys): def test(self): msg = mkAssocResponse(*keys) - self.assertRaises(KeyError, - self.consumer._extractAssociation, msg, None) + self.assertRaises(KeyError, self.consumer._extractAssociation, msg, + None) return test @@ -132,47 +132,41 @@ class ExtractAssociationSessionTypeMismatch(BaseAssocTest): msg = mkAssocResponse(*keys) msg.setArg(OPENID_NS, 'session_type', response_session_type) self.failUnlessProtocolError('Session type mismatch', - self.consumer._extractAssociation, msg, assoc_session) + self.consumer._extractAssociation, + msg, assoc_session) return test test_typeMismatchNoEncBlank_openid2 = mkTest( requested_session_type='no-encryption', - response_session_type='', - ) + response_session_type='', ) test_typeMismatchDHSHA1NoEnc_openid2 = mkTest( requested_session_type='DH-SHA1', - response_session_type='no-encryption', - ) + response_session_type='no-encryption', ) test_typeMismatchDHSHA256NoEnc_openid2 = mkTest( requested_session_type='DH-SHA256', - response_session_type='no-encryption', - ) + response_session_type='no-encryption', ) test_typeMismatchNoEncDHSHA1_openid2 = mkTest( requested_session_type='no-encryption', - response_session_type='DH-SHA1', - ) + response_session_type='DH-SHA1', ) test_typeMismatchDHSHA1NoEnc_openid1 = mkTest( requested_session_type='DH-SHA1', response_session_type='DH-SHA256', - openid1=True, - ) + openid1=True, ) test_typeMismatchDHSHA256NoEnc_openid1 = mkTest( requested_session_type='DH-SHA256', response_session_type='DH-SHA1', - openid1=True, - ) + openid1=True, ) test_typeMismatchNoEncDHSHA1_openid1 = mkTest( requested_session_type='no-encryption', response_session_type='DH-SHA1', - openid1=True, - ) + openid1=True, ) class TestOpenID1AssociationResponseSessionType(BaseAssocTest): @@ -181,6 +175,7 @@ class TestOpenID1AssociationResponseSessionType(BaseAssocTest): be used if the OpenID 1 response to an associate call sets the 'session_type' field to `session_type_value` """ + def test(self): self._doTest(expected_session_type, session_type_value) self.assertEqual(0, len(self.messages)) @@ -202,35 +197,31 @@ class TestOpenID1AssociationResponseSessionType(BaseAssocTest): 'to yield session type %r, but yielded %r' % (session_type_value, expected_session_type, actual_session_type)) - self.assertEqual( - expected_session_type, actual_session_type, error_message) + self.assertEqual(expected_session_type, actual_session_type, + error_message) test_none = mkTest( session_type_value=None, - expected_session_type='no-encryption', - ) + expected_session_type='no-encryption', ) test_empty = mkTest( session_type_value='', - expected_session_type='no-encryption', - ) + expected_session_type='no-encryption', ) # This one's different because it expects log messages def test_explicitNoEncryption(self): self._doTest( session_type_value='no-encryption', - expected_session_type='no-encryption', - ) + expected_session_type='no-encryption', ) self.assertEqual(1, len(self.messages)) log_msg = self.messages[0] self.assertEqual(log_msg['levelname'], 'WARNING') - self.assertTrue(log_msg['msg'].startswith( - 'OpenID server sent "no-encryption"')) + self.assertTrue( + log_msg['msg'].startswith('OpenID server sent "no-encryption"')) test_dhSHA1 = mkTest( session_type_value='DH-SHA1', - expected_session_type='DH-SHA1', - ) + expected_session_type='DH-SHA1', ) # DH-SHA256 is not a valid session type for OpenID1, but this # function does not test that. This is mostly just to make sure @@ -239,8 +230,7 @@ class TestOpenID1AssociationResponseSessionType(BaseAssocTest): # 2 test_dhSHA256 = mkTest( session_type_value='DH-SHA256', - expected_session_type='DH-SHA256', - ) + expected_session_type='DH-SHA256', ) class DummyAssociationSession(object): @@ -268,12 +258,17 @@ class TestInvalidFields(BaseAssocTest): # These arguments should all be valid self.assoc_response = Message.fromOpenIDArgs({ - 'expires_in': '1000', - 'assoc_handle': self.assoc_handle, - 'assoc_type': self.assoc_type, - 'session_type': self.session_type, - 'ns': OPENID2_NS, - }) + 'expires_in': + '1000', + 'assoc_handle': + self.assoc_handle, + 'assoc_type': + self.assoc_type, + 'session_type': + self.session_type, + 'ns': + OPENID2_NS, + }) self.assoc_session = DummyAssociationSession() @@ -283,8 +278,8 @@ class TestInvalidFields(BaseAssocTest): def test_worksWithGoodFields(self): """Handle a full successful association response""" - assoc = self.consumer._extractAssociation( - self.assoc_response, self.assoc_session) + assoc = self.consumer._extractAssociation(self.assoc_response, + self.assoc_session) self.assertTrue(self.assoc_session.extract_secret_called) self.assertEqual(self.assoc_session.secret, assoc.secret) self.assertEqual(1000, assoc.lifetime) @@ -296,15 +291,15 @@ class TestInvalidFields(BaseAssocTest): # for the given session. self.assoc_session.allowed_assoc_types = [] self.failUnlessProtocolError('Unsupported assoc_type for session', - self.consumer._extractAssociation, - self.assoc_response, self.assoc_session) + self.consumer._extractAssociation, + self.assoc_response, self.assoc_session) def test_badExpiresIn(self): # Invalid value for expires_in should cause failure self.assoc_response.setArg(OPENID_NS, 'expires_in', 'forever') self.failUnlessProtocolError('Invalid expires_in', - self.consumer._extractAssociation, - self.assoc_response, self.assoc_session) + self.consumer._extractAssociation, + self.assoc_response, self.assoc_session) # XXX: This is what causes most of the imports in this file. It is @@ -319,7 +314,7 @@ class TestExtractAssociationDiffieHellman(BaseAssocTest): # XXX: this is testing _createAssociateRequest self.assertEqual(self.endpoint.compatibilityMode(), - message.isOpenID1()) + message.isOpenID1()) server_sess = DiffieHellmanSHA1ServerSession.fromMessage(message) server_resp = server_sess.answer(self.secret) @@ -348,4 +343,5 @@ class TestExtractAssociationDiffieHellman(BaseAssocTest): sess, server_resp = self._setUpDH() server_resp.setArg(OPENID_NS, 'enc_mac_key', '\x00\x00\x00') self.failUnlessProtocolError('Malformed response for', - self.consumer._extractAssociation, server_resp, sess) + self.consumer._extractAssociation, + server_resp, sess) diff --git a/openid/test/test_auth_request.py b/openid/test/test_auth_request.py index 378a50fc028ba4b2d03bd98e7e33bc040a73d98b..53f5fabb2c6e0142ab474b06c49c484f52c0f849 100644 --- a/openid/test/test_auth_request.py +++ b/openid/test/test_auth_request.py @@ -50,13 +50,11 @@ class TestAuthRequestMixin(support.OpenIDTestMixin): def failUnlessHasRequiredFields(self, msg): self.assertEqual(self.preferred_namespace, - self.authreq.message.getOpenIDNamespace()) + self.authreq.message.getOpenIDNamespace()) - self.assertEqual(self.preferred_namespace, - msg.getOpenIDNamespace()) + self.assertEqual(self.preferred_namespace, msg.getOpenIDNamespace()) - self.failUnlessOpenIDValueEquals(msg, 'mode', - self.expected_mode) + self.failUnlessOpenIDValueEquals(msg, 'mode', self.expected_mode) # Implement these in subclasses because they depend on # protocol differences! @@ -83,9 +81,10 @@ class TestAuthRequestMixin(support.OpenIDTestMixin): self.authreq.addExtensionArg('bag:', 'color', 'brown') self.authreq.addExtensionArg('bag:', 'material', 'paper') self.assertTrue('bag:' in self.authreq.message.namespaces) - self.assertEqual(self.authreq.message.getArgs('bag:'), - {'color': 'brown', - 'material': 'paper'}) + self.assertEqual( + self.authreq.message.getArgs('bag:'), + {'color': 'brown', + 'material': 'paper'}) msg = self.authreq.getMessage(self.realm, self.return_to, self.immediate) @@ -100,8 +99,8 @@ class TestAuthRequestMixin(support.OpenIDTestMixin): msg = self.authreq.getMessage(self.realm, self.return_to, self.immediate) - self.failUnlessHasIdentifiers( - msg, self.endpoint.local_id, self.endpoint.claimed_id) + self.failUnlessHasIdentifiers(msg, self.endpoint.local_id, + self.endpoint.claimed_id) class TestAuthRequestOpenID2(TestAuthRequestMixin, unittest.TestCase): @@ -151,8 +150,8 @@ class TestAuthRequestOpenID2(TestAuthRequestMixin, unittest.TestCase): msg = self.authreq.getMessage(self.realm, self.return_to, self.immediate) self.failUnlessHasRequiredFields(msg) - self.failUnlessHasIdentifiers( - msg, message.IDENTIFIER_SELECT, message.IDENTIFIER_SELECT) + self.failUnlessHasIdentifiers(msg, message.IDENTIFIER_SELECT, + message.IDENTIFIER_SELECT) class TestAuthRequestOpenID1(TestAuthRequestMixin, unittest.TestCase): @@ -196,7 +195,7 @@ class TestAuthRequestOpenID1(TestAuthRequestMixin, unittest.TestCase): self.immediate) self.failUnlessHasRequiredFields(msg) self.assertEqual(message.IDENTIFIER_SELECT, - msg.getArg(message.OPENID1_NS, 'identity')) + msg.getArg(message.OPENID1_NS, 'identity')) class TestAuthRequestOpenID1Immediate(TestAuthRequestOpenID1): @@ -208,5 +207,6 @@ class TestAuthRequestOpenID2Immediate(TestAuthRequestOpenID2): immediate = True expected_mode = 'checkid_immediate' + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py index 8d4aaf68712366baa03cab38238d68ce80829a1d..0702a8256e2f454bff995ed0389c96bb72302b59 100644 --- a/openid/test/test_ax.py +++ b/openid/test/test_ax.py @@ -58,9 +58,8 @@ class ToTypeURIsTest(unittest.TestCase): self.assertEqual([], uris) def test_undefined(self): - self.assertRaises( - KeyError, - ax.toTypeURIs, self.aliases, 'http://janrain.com/') + self.assertRaises(KeyError, ax.toTypeURIs, self.aliases, + 'http://janrain.com/') def test_one(self): uri = 'http://janrain.com/' @@ -102,24 +101,23 @@ class ParseAXValuesTest(unittest.TestCase): self.failUnlessAXKeyError({'type.foo': 'urn:foo'}) def test_countPresentButNotValue(self): - self.failUnlessAXKeyError({'type.foo': 'urn:foo', - 'count.foo': '1'}) + self.failUnlessAXKeyError({'type.foo': 'urn:foo', 'count.foo': '1'}) def test_invalidCountValue(self): msg = ax.FetchRequest() - self.assertRaises(ax.AXError, - msg.parseExtensionArgs, - {'type.foo': 'urn:foo', - 'count.foo': 'bogus'}) + self.assertRaises(ax.AXError, msg.parseExtensionArgs, + {'type.foo': 'urn:foo', + 'count.foo': 'bogus'}) def test_requestUnlimitedValues(self): msg = ax.FetchRequest() - msg.parseExtensionArgs( - {'mode': 'fetch_request', - 'required': 'foo', - 'type.foo': 'urn:foo', - 'count.foo': ax.UNLIMITED_VALUES}) + msg.parseExtensionArgs({ + 'mode': 'fetch_request', + 'required': 'foo', + 'type.foo': 'urn:foo', + 'count.foo': ax.UNLIMITED_VALUES + }) attrs = list(msg.iterAttrs()) foo = attrs[0] @@ -133,64 +131,65 @@ class ParseAXValuesTest(unittest.TestCase): alias = 'x' * ax.MINIMUM_SUPPORTED_ALIAS_LENGTH msg = ax.AXKeyValueMessage() - msg.parseExtensionArgs( - {'type.%s' % (alias,): 'urn:foo', - 'count.%s' % (alias,): '1', - 'value.%s.1' % (alias,): 'first'} - ) + msg.parseExtensionArgs({ + 'type.%s' % (alias, ): 'urn:foo', + 'count.%s' % (alias, ): '1', + 'value.%s.1' % (alias, ): 'first' + }) def test_invalidAlias(self): - types = [ - ax.AXKeyValueMessage, - ax.FetchRequest - ] + types = [ax.AXKeyValueMessage, ax.FetchRequest] inputs = [ - {'type.a.b':'urn:foo', - 'count.a.b':'1'}, - {'type.a,b':'urn:foo', - 'count.a,b':'1'}, - ] + { + 'type.a.b': 'urn:foo', + 'count.a.b': '1' + }, + { + 'type.a,b': 'urn:foo', + 'count.a,b': '1' + }, + ] for typ in types: for input in inputs: msg = typ() - self.assertRaises(ax.AXError, msg.parseExtensionArgs, - input) + self.assertRaises(ax.AXError, msg.parseExtensionArgs, input) def test_countPresentAndIsZero(self): - self.failUnlessAXValues( - {'type.foo': 'urn:foo', - 'count.foo': '0', - }, {'urn:foo': []}) + self.failUnlessAXValues({ + 'type.foo': 'urn:foo', + 'count.foo': '0', + }, {'urn:foo': []}) def test_singletonEmpty(self): - self.failUnlessAXValues( - {'type.foo': 'urn:foo', - 'value.foo': '', - }, {'urn:foo': []}) + self.failUnlessAXValues({ + 'type.foo': 'urn:foo', + 'value.foo': '', + }, {'urn:foo': []}) def test_doubleAlias(self): - self.failUnlessAXKeyError( - {'type.foo': 'urn:foo', - 'value.foo': '', - 'type.bar': 'urn:foo', - 'value.bar': '', - }) + self.failUnlessAXKeyError({ + 'type.foo': 'urn:foo', + 'value.foo': '', + 'type.bar': 'urn:foo', + 'value.bar': '', + }) def test_doubleSingleton(self): - self.failUnlessAXValues( - {'type.foo': 'urn:foo', - 'value.foo': '', - 'type.bar': 'urn:bar', - 'value.bar': '', - }, {'urn:foo': [], 'urn:bar': []}) + self.failUnlessAXValues({ + 'type.foo': 'urn:foo', + 'value.foo': '', + 'type.bar': 'urn:bar', + 'value.bar': '', + }, {'urn:foo': [], + 'urn:bar': []}) def test_singletonValue(self): - self.failUnlessAXValues( - {'type.foo': 'urn:foo', - 'value.foo': 'Westfall', - }, {'urn:foo': ['Westfall']}) + self.failUnlessAXValues({ + 'type.foo': 'urn:foo', + 'value.foo': 'Westfall', + }, {'urn:foo': ['Westfall']}) class FetchRequestTest(unittest.TestCase): @@ -232,7 +231,7 @@ class FetchRequestTest(unittest.TestCase): def test_getExtensionArgs_empty(self): expected_args = { 'mode': 'fetch_request', - } + } self.assertEqual(expected_args, self.msg.getExtensionArgs()) def test_getExtensionArgs_noAlias(self): @@ -249,30 +248,28 @@ class FetchRequestTest(unittest.TestCase): self.failUnlessExtensionArgs({ 'type.' + alias: attr.type_uri, 'if_available': alias, - }) + }) def test_getExtensionArgs_alias_if_available(self): attr = ax.AttrInfo( type_uri='type://of.transportation', - alias='transport', - ) + alias='transport', ) self.msg.add(attr) self.failUnlessExtensionArgs({ 'type.' + attr.alias: attr.type_uri, 'if_available': attr.alias, - }) + }) def test_getExtensionArgs_alias_req(self): attr = ax.AttrInfo( type_uri='type://of.transportation', alias='transport', - required=True, - ) + required=True, ) self.msg.add(attr) self.failUnlessExtensionArgs({ 'type.' + attr.alias: attr.type_uri, 'required': attr.alias, - }) + }) def failUnlessExtensionArgs(self, expected_args): """Make sure that getExtensionArgs has the expected result @@ -294,16 +291,16 @@ class FetchRequestTest(unittest.TestCase): extension_args = { 'mode': 'fetch_request', 'type.' + self.alias_a: self.type_a, - } - self.assertRaises(ValueError, - self.msg.parseExtensionArgs, extension_args) + } + self.assertRaises(ValueError, self.msg.parseExtensionArgs, + extension_args) def test_parseExtensionArgs(self): extension_args = { 'mode': 'fetch_request', 'type.' + self.alias_a: self.type_a, 'if_available': self.alias_a - } + } self.msg.parseExtensionArgs(extension_args) self.assertTrue(self.type_a in self.msg) self.assertEqual([self.type_a], list(self.msg)) @@ -319,7 +316,7 @@ class FetchRequestTest(unittest.TestCase): 'mode': 'fetch_request', 'type.' + self.alias_a: self.type_a, 'if_available': self.alias_a - } + } self.msg.parseExtensionArgs(extension_args) self.assertEqual(extension_args, self.msg.getExtensionArgs()) self.assertFalse(self.msg.requested_attributes[self.type_a].required) @@ -330,7 +327,7 @@ class FetchRequestTest(unittest.TestCase): 'type.' + self.alias_a: self.type_a, 'count.' + self.alias_a: '2', 'required': self.alias_a - } + } self.msg.parseExtensionArgs(extension_args) self.assertEqual(extension_args, self.msg.getExtensionArgs()) self.assertTrue(self.msg.requested_attributes[self.type_a].required) @@ -341,62 +338,83 @@ class FetchRequestTest(unittest.TestCase): 'type.' + self.alias_a: self.type_a, 'count.' + self.alias_a: '1', 'if_available': self.alias_a, - } + } extension_args_norm = { 'mode': 'fetch_request', 'type.' + self.alias_a: self.type_a, 'if_available': self.alias_a, - } + } self.msg.parseExtensionArgs(extension_args) self.assertEqual(extension_args_norm, self.msg.getExtensionArgs()) def test_openidNoRealm(self): openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'checkid_setup', - 'ns': OPENID2_NS, - 'ns.ax': ax.AXMessage.ns_uri, - 'ax.update_url': 'http://different.site/path', - 'ax.mode': 'fetch_request', - }) - self.assertRaises(ax.AXError, - ax.FetchRequest.fromOpenIDRequest, - DummyRequest(openid_req_msg)) + 'mode': + 'checkid_setup', + 'ns': + OPENID2_NS, + 'ns.ax': + ax.AXMessage.ns_uri, + 'ax.update_url': + 'http://different.site/path', + 'ax.mode': + 'fetch_request', + }) + self.assertRaises(ax.AXError, ax.FetchRequest.fromOpenIDRequest, + DummyRequest(openid_req_msg)) def test_openidUpdateURLVerificationError(self): openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'checkid_setup', - 'ns': OPENID2_NS, - 'realm': 'http://example.com/realm', - 'ns.ax': ax.AXMessage.ns_uri, - 'ax.update_url': 'http://different.site/path', - 'ax.mode': 'fetch_request', - }) - - self.assertRaises(ax.AXError, - ax.FetchRequest.fromOpenIDRequest, - DummyRequest(openid_req_msg)) + 'mode': + 'checkid_setup', + 'ns': + OPENID2_NS, + 'realm': + 'http://example.com/realm', + 'ns.ax': + ax.AXMessage.ns_uri, + 'ax.update_url': + 'http://different.site/path', + 'ax.mode': + 'fetch_request', + }) + + self.assertRaises(ax.AXError, ax.FetchRequest.fromOpenIDRequest, + DummyRequest(openid_req_msg)) def test_openidUpdateURLVerificationSuccess(self): openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'checkid_setup', - 'ns': OPENID2_NS, - 'realm': 'http://example.com/realm', - 'ns.ax': ax.AXMessage.ns_uri, - 'ax.update_url': 'http://example.com/realm/update_path', - 'ax.mode': 'fetch_request', - }) + 'mode': + 'checkid_setup', + 'ns': + OPENID2_NS, + 'realm': + 'http://example.com/realm', + 'ns.ax': + ax.AXMessage.ns_uri, + 'ax.update_url': + 'http://example.com/realm/update_path', + 'ax.mode': + 'fetch_request', + }) fr = ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg)) def test_openidUpdateURLVerificationSuccessReturnTo(self): openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'checkid_setup', - 'ns': OPENID2_NS, - 'return_to': 'http://example.com/realm', - 'ns.ax': ax.AXMessage.ns_uri, - 'ax.update_url': 'http://example.com/realm/update_path', - 'ax.mode': 'fetch_request', - }) + 'mode': + 'checkid_setup', + 'ns': + OPENID2_NS, + 'return_to': + 'http://example.com/realm', + 'ns.ax': + ax.AXMessage.ns_uri, + 'ax.update_url': + 'http://example.com/realm/update_path', + 'ax.mode': + 'fetch_request', + }) fr = ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg)) @@ -405,21 +423,26 @@ class FetchRequestTest(unittest.TestCase): openid_req_msg = Message.fromOpenIDArgs({ 'mode': 'checkid_setup', 'ns': OPENID2_NS, - }) + }) oreq = DummyRequest(openid_req_msg) r = ax.FetchRequest.fromOpenIDRequest(oreq) - self.assertTrue(r is None, "%s is not None" % (r,)) + self.assertTrue(r is None, "%s is not None" % (r, )) def test_fromOpenIDRequestWithoutData(self): """return something for SuccessResponse with AX paramaters, even if it is the empty set.""" openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'checkid_setup', - 'realm': 'http://example.com/realm', - 'ns': OPENID2_NS, - 'ns.ax': ax.AXMessage.ns_uri, - 'ax.mode': 'fetch_request', - }) + 'mode': + 'checkid_setup', + 'realm': + 'http://example.com/realm', + 'ns': + OPENID2_NS, + 'ns.ax': + ax.AXMessage.ns_uri, + 'ax.mode': + 'fetch_request', + }) oreq = DummyRequest(openid_req_msg) r = ax.FetchRequest.fromOpenIDRequest(oreq) self.assertTrue(r is not None) @@ -440,13 +463,13 @@ class FetchResponseTest(unittest.TestCase): def test_getExtensionArgs_empty(self): expected_args = { 'mode': 'fetch_response', - } + } self.assertEqual(expected_args, self.msg.getExtensionArgs()) def test_getExtensionArgs_empty_request(self): expected_args = { 'mode': 'fetch_response', - } + } req = ax.FetchRequest() msg = ax.FetchResponse(request=req) self.assertEqual(expected_args, msg.getExtensionArgs()) @@ -457,9 +480,9 @@ class FetchResponseTest(unittest.TestCase): expected_args = { 'mode': 'fetch_response', - 'type.%s' % (alias,): uri, - 'count.%s' % (alias,): '0' - } + 'type.%s' % (alias, ): uri, + 'count.%s' % (alias, ): '0' + } req = ax.FetchRequest() req.add(ax.AttrInfo(uri)) msg = ax.FetchResponse(request=req) @@ -472,9 +495,9 @@ class FetchResponseTest(unittest.TestCase): expected_args = { 'mode': 'fetch_response', 'update_url': self.request_update_url, - 'type.%s' % (alias,): uri, - 'count.%s' % (alias,): '0' - } + 'type.%s' % (alias, ): uri, + 'count.%s' % (alias, ): '0' + } req = ax.FetchRequest(update_url=self.request_update_url) req.add(ax.AttrInfo(uri)) msg = ax.FetchResponse(request=req) @@ -486,7 +509,7 @@ class FetchResponseTest(unittest.TestCase): 'type.' + self.alias_a: self.type_a, 'value.' + self.alias_a + '.1': self.value_a, 'count.' + self.alias_a: '1' - } + } req = ax.FetchRequest() req.add(ax.AttrInfo(self.type_a, alias=self.alias_a)) msg = ax.FetchResponse(request=req) @@ -519,7 +542,7 @@ class FetchResponseTest(unittest.TestCase): args = { 'mode': 'id_res', 'ns': OPENID2_NS, - } + } sf = ['openid.' + i for i in list(args.keys())] msg = Message.fromOpenIDArgs(args) @@ -528,7 +551,7 @@ class FetchResponseTest(unittest.TestCase): oreq = SuccessResponse(Endpoint(), msg, signed_fields=sf) r = ax.FetchResponse.fromSuccessResponse(oreq) - self.assertTrue(r is None, "%s is not None" % (r,)) + self.assertTrue(r is None, "%s is not None" % (r, )) def test_fromSuccessResponseWithoutData(self): """return something for SuccessResponse with AX paramaters, @@ -538,7 +561,7 @@ class FetchResponseTest(unittest.TestCase): 'ns': OPENID2_NS, 'ns.ax': ax.AXMessage.ns_uri, 'ax.mode': 'fetch_response', - } + } sf = ['openid.' + i for i in list(args.keys())] msg = Message.fromOpenIDArgs(args) @@ -562,7 +585,7 @@ class FetchResponseTest(unittest.TestCase): 'ax.type.' + name: uri, 'ax.count.' + name: '1', 'ax.value.%s.1' % name: value, - } + } sf = ['openid.' + i for i in list(args.keys())] msg = Message.fromOpenIDArgs(args) @@ -589,7 +612,7 @@ class StoreRequestTest(unittest.TestCase): args = self.msg.getExtensionArgs() expected_args = { 'mode': 'store_request', - } + } self.assertEqual(expected_args, args) def test_getExtensionArgs_nonempty(self): @@ -602,9 +625,9 @@ class StoreRequestTest(unittest.TestCase): 'mode': 'store_request', 'type.' + self.alias_a: self.type_a, 'count.' + self.alias_a: '2', - 'value.%s.1' % (self.alias_a,): 'foo', - 'value.%s.2' % (self.alias_a,): 'bar', - } + 'value.%s.1' % (self.alias_a, ): 'foo', + 'value.%s.2' % (self.alias_a, ): 'bar', + } self.assertEqual(expected_args, args) @@ -613,20 +636,24 @@ class StoreResponseTest(unittest.TestCase): msg = ax.StoreResponse() self.assertTrue(msg.succeeded()) self.assertFalse(msg.error_message) - self.assertEqual({'mode': 'store_response_success'}, - msg.getExtensionArgs()) + self.assertEqual({ + 'mode': 'store_response_success' + }, msg.getExtensionArgs()) def test_fail_nomsg(self): msg = ax.StoreResponse(False) self.assertFalse(msg.succeeded()) self.assertFalse(msg.error_message) - self.assertEqual({'mode': 'store_response_failure'}, - msg.getExtensionArgs()) + self.assertEqual({ + 'mode': 'store_response_failure' + }, msg.getExtensionArgs()) def test_fail_msg(self): reason = 'no reason, really' msg = ax.StoreResponse(False, reason) self.assertFalse(msg.succeeded()) self.assertEqual(reason, msg.error_message) - self.assertEqual({'mode': 'store_response_failure', - 'error': reason}, msg.getExtensionArgs()) + self.assertEqual({ + 'mode': 'store_response_failure', + 'error': reason + }, msg.getExtensionArgs()) diff --git a/openid/test/test_codecutil.py b/openid/test/test_codecutil.py index bb19d9e063a3f5a821c62c11c57f5d171c7d2214..30995a23760e760d3f04966c179c547dce66e805 100644 --- a/openid/test/test_codecutil.py +++ b/openid/test/test_codecutil.py @@ -5,16 +5,14 @@ from openid import codecutil # registers encoder class EncoderTest(unittest.TestCase): def test_handler_registered(self): - self.assertEqual( - "foo".encode('ascii', errors='oid_percent_escape'), - b"foo") + self.assertEqual("foo".encode('ascii', errors='oid_percent_escape'), + b"foo") def test_encoding(self): s = 'l\xa1m\U00101010n' expected = b'l%C2%A1m%F4%81%80%90n' self.assertEqual( - s.encode('ascii', errors='oid_percent_escape'), - expected) + s.encode('ascii', errors='oid_percent_escape'), expected) if __name__ == '__main__': diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index c3df2810c5370b1a49e5848211e2d6bb36445aa7..302c92b9613a43addc5d01ae46aefba3f63a5df6 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -31,7 +31,7 @@ from .support import CatchLogs assocs = [ ('another 20-byte key.', 'Snarky'), ('\x00' * 20, 'Zeros'), - ] +] def mkSuccess(endpoint, q): @@ -59,7 +59,7 @@ def associate(qs, assoc_secret, assoc_handle): 'assoc_type': 'HMAC-SHA1', 'assoc_handle': assoc_handle, 'expires_in': '600', - } + } if q.get('openid.session_type') == 'DH-SHA1': assert len(q) == 6 or len(q) == 4 @@ -115,8 +115,8 @@ class TestFetcher(object): pass # fall through else: assert body.find('DH-SHA1') != -1 - response = associate( - body, self.assoc_secret, self.assoc_handle) + response = associate(body, self.assoc_secret, + self.assoc_handle) self.num_assocs += 1 return self.response(url, 200, response) @@ -227,6 +227,7 @@ def _test_success(server_url, user_url, delegate_url, links, immediate=False): run() assert fetcher.num_assocs == 2 + import unittest http_server_url = b'http://server.example.com/' @@ -242,30 +243,30 @@ class TestSuccess(unittest.TestCase, CatchLogs): def setUp(self): CatchLogs.setUp(self) self.links = '<link rel="openid.server" href="%s" />' % ( - self.server_url,) + self.server_url, ) self.delegate_links = ('<link rel="openid.server" href="%s" />' '<link rel="openid.delegate" href="%s" />') % ( - self.server_url, self.delegate_url) + self.server_url, self.delegate_url) def tearDown(self): CatchLogs.tearDown(self) def test_nodelegate(self): - _test_success(self.server_url, self.user_url, - self.user_url, self.links) + _test_success(self.server_url, self.user_url, self.user_url, + self.links) def test_nodelegateImmediate(self): - _test_success(self.server_url, self.user_url, - self.user_url, self.links, True) + _test_success(self.server_url, self.user_url, self.user_url, + self.links, True) def test_delegate(self): - _test_success(self.server_url, self.user_url, - self.delegate_url, self.delegate_links) + _test_success(self.server_url, self.user_url, self.delegate_url, + self.delegate_links) def test_delegateImmediate(self): - _test_success(self.server_url, self.user_url, - self.delegate_url, self.delegate_links, True) + _test_success(self.server_url, self.user_url, self.delegate_url, + self.delegate_links, True) class TestSuccessHTTPS(TestSuccess): @@ -302,18 +303,22 @@ class TestIdRes(unittest.TestCase, CatchLogs): def disableDiscoveryVerification(self): """Set the discovery verification to a no-op for test cases in which we don't care.""" + def dummyVerifyDiscover(_, endpoint): return endpoint + self.consumer._verifyDiscoveryResults = dummyVerifyDiscover def disableReturnToChecking(self): def checkReturnTo(unused1, unused2): return True + self.consumer._checkReturnTo = checkReturnTo complete = self.consumer.complete def callCompleteWithoutReturnTo(message, endpoint): return complete(message, endpoint, None) + self.consumer.complete = callCompleteWithoutReturnTo @@ -325,13 +330,19 @@ class TestIdResCheckSignature(TestIdRes): self.store.storeAssociation(self.endpoint.server_url, self.assoc) self.message = Message.fromPostArgs({ - 'openid.mode': 'id_res', - 'openid.identity': '=example', - 'openid.sig': GOODSIG, - 'openid.assoc_handle': self.assoc.handle, - 'openid.signed': 'mode,identity,assoc_handle,signed', - 'frobboz': 'banzit', - }) + 'openid.mode': + 'id_res', + 'openid.identity': + '=example', + 'openid.sig': + GOODSIG, + 'openid.assoc_handle': + self.assoc.handle, + 'openid.signed': + 'mode,identity,assoc_handle,signed', + 'frobboz': + 'banzit', + }) def test_sign(self): # assoc_handle to assoc with good sig @@ -340,9 +351,8 @@ class TestIdResCheckSignature(TestIdRes): def test_signFailsWithBadSig(self): self.message.setArg(OPENID_NS, 'sig', 'BAD SIGNATURE') - self.assertRaises( - ProtocolError, self.consumer._idResCheckSignature, - self.message, self.endpoint.server_url) + self.assertRaises(ProtocolError, self.consumer._idResCheckSignature, + self.message, self.endpoint.server_url) def test_stateless(self): # assoc_handle missing assoc, consumer._checkAuth returns goodthings @@ -357,9 +367,8 @@ class TestIdResCheckSignature(TestIdRes): # assoc_handle missing assoc, consumer._checkAuth returns goodthings self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle") self.consumer._checkAuth = lambda unused1, unused2: False - self.assertRaises( - ProtocolError, self.consumer._idResCheckSignature, - self.message, self.endpoint.server_url) + self.assertRaises(ProtocolError, self.consumer._idResCheckSignature, + self.message, self.endpoint.server_url) def test_stateless_noStore(self): # assoc_handle missing assoc, consumer._checkAuth returns goodthings @@ -376,9 +385,8 @@ class TestIdResCheckSignature(TestIdRes): self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle") self.consumer._checkAuth = lambda unused1, unused2: False self.consumer.store = None - self.assertRaises( - ProtocolError, self.consumer._idResCheckSignature, - self.message, self.endpoint.server_url) + self.assertRaises(ProtocolError, self.consumer._idResCheckSignature, + self.message, self.endpoint.server_url) class TestQueryFormat(TestIdRes): @@ -393,7 +401,7 @@ class TestQueryFormat(TestIdRes): except TypeError as err: self.assertTrue(str(err).find('values') != -1, err) else: - self.fail("expected TypeError, got this instead: %s" % (r,)) + self.fail("expected TypeError, got this instead: %s" % (r, )) class TestComplete(TestIdRes): @@ -431,9 +439,10 @@ class TestComplete(TestIdRes): def test_error(self): msg = 'an error message' - message = Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': msg, - }) + message = Message.fromPostArgs({ + 'openid.mode': 'error', + 'openid.error': msg, + }) self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) self.assertEqual(r.status, FAILURE) @@ -443,10 +452,11 @@ class TestComplete(TestIdRes): def test_errorWithNoOptionalKeys(self): msg = 'an error message' contact = 'some contact info here' - message = Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': msg, - 'openid.contact': contact, - }) + message = Message.fromPostArgs({ + 'openid.mode': 'error', + 'openid.error': msg, + 'openid.contact': contact, + }) self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) self.assertEqual(r.status, FAILURE) @@ -459,10 +469,13 @@ class TestComplete(TestIdRes): msg = 'an error message' contact = 'me' reference = 'support ticket' - message = Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': msg, 'openid.reference': reference, - 'openid.contact': contact, 'openid.ns': OPENID2_NS, - }) + message = Message.fromPostArgs({ + 'openid.mode': 'error', + 'openid.error': msg, + 'openid.reference': reference, + 'openid.contact': contact, + 'openid.ns': OPENID2_NS, + }) r = self.consumer.complete(message, self.endpoint, None) self.assertEqual(r.status, FAILURE) self.assertTrue(r.identity_url == self.endpoint.claimed_id) @@ -481,8 +494,8 @@ class TestComplete(TestIdRes): # is supposed to test for. status in FAILURE, but it's because # *check_auth* failed, not because it's missing an arg, exactly. message = Message.fromPostArgs({'openid.mode': 'id_res'}) - self.assertRaises(ProtocolError, self.consumer._doIdRes, - message, self.endpoint, None) + self.assertRaises(ProtocolError, self.consumer._doIdRes, message, + self.endpoint, None) def test_idResURLMismatch(self): class VerifiedError(Exception): @@ -494,19 +507,24 @@ class TestComplete(TestIdRes): self.consumer._discoverAndVerify = discoverAndVerify self.disableReturnToChecking() - message = Message.fromPostArgs( - {'openid.mode': 'id_res', - 'openid.return_to': 'return_to (just anything)', - 'openid.identity': 'something wrong (not self.consumer_id)', - 'openid.assoc_handle': 'does not matter', - 'openid.sig': GOODSIG, - 'openid.signed': 'identity,return_to', - }) + message = Message.fromPostArgs({ + 'openid.mode': + 'id_res', + 'openid.return_to': + 'return_to (just anything)', + 'openid.identity': + 'something wrong (not self.consumer_id)', + 'openid.assoc_handle': + 'does not matter', + 'openid.sig': + GOODSIG, + 'openid.signed': + 'identity,return_to', + }) self.consumer.store = GoodAssocStore() - self.assertRaises(VerifiedError, - self.consumer.complete, - message, self.endpoint) + self.assertRaises(VerifiedError, self.consumer.complete, message, + self.endpoint) self.failUnlessLogMatches('Error attempting to use stored', 'Attempting discovery') @@ -521,18 +539,28 @@ class TestCompleteMissingSig(unittest.TestCase, CatchLogs): claimed_id = 'bogus.claimed' - self.message = Message.fromOpenIDArgs( - {'mode': 'id_res', - 'return_to': 'return_to (just anything)', - 'identity': claimed_id, - 'assoc_handle': 'does not matter', - 'sig': GOODSIG, - 'response_nonce': mkNonce(), - 'signed': 'identity,return_to,response_nonce,assoc_handle,claimed_id,op_endpoint', - 'claimed_id': claimed_id, - 'op_endpoint': self.server_url, - 'ns': OPENID2_NS, - }) + self.message = Message.fromOpenIDArgs({ + 'mode': + 'id_res', + 'return_to': + 'return_to (just anything)', + 'identity': + claimed_id, + 'assoc_handle': + 'does not matter', + 'sig': + GOODSIG, + 'response_nonce': + mkNonce(), + 'signed': + 'identity,return_to,response_nonce,assoc_handle,claimed_id,op_endpoint', + 'claimed_id': + claimed_id, + 'op_endpoint': + self.server_url, + 'ns': + OPENID2_NS, + }) self.endpoint = OpenIDServiceEndpoint() self.endpoint.server_url = self.server_url @@ -555,43 +583,38 @@ class TestCompleteMissingSig(unittest.TestCase, CatchLogs): self.message.delArg(OPENID_NS, 'claimed_id') self.endpoint.claimed_id = None self.message.setArg( - OPENID_NS, - 'signed', + OPENID_NS, 'signed', 'return_to,response_nonce,assoc_handle,op_endpoint') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessSuccess(r) def test_idResMissingIdentitySig(self): - self.message.setArg( - OPENID_NS, 'signed', - 'return_to,response_nonce,assoc_handle,claimed_id') + self.message.setArg(OPENID_NS, 'signed', + 'return_to,response_nonce,assoc_handle,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) self.assertEqual(r.status, FAILURE) def test_idResMissingReturnToSig(self): - self.message.setArg( - OPENID_NS, 'signed', - 'identity,response_nonce,assoc_handle,claimed_id') + self.message.setArg(OPENID_NS, 'signed', + 'identity,response_nonce,assoc_handle,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) self.assertEqual(r.status, FAILURE) def test_idResMissingAssocHandleSig(self): - self.message.setArg( - OPENID_NS, 'signed', - 'identity,response_nonce,return_to,claimed_id') + self.message.setArg(OPENID_NS, 'signed', + 'identity,response_nonce,return_to,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) self.assertEqual(r.status, FAILURE) def test_idResMissingClaimedIDSig(self): - self.message.setArg( - OPENID_NS, 'signed', - 'identity,response_nonce,return_to,assoc_handle') + self.message.setArg(OPENID_NS, 'signed', + 'identity,response_nonce,return_to,assoc_handle') r = self.consumer.complete(self.message, self.endpoint, None) self.assertEqual(r.status, FAILURE) def failUnlessSuccess(self, response): if response.status != SUCCESS: - self.fail("Non-successful response: %s" % (response,)) + self.fail("Non-successful response: %s" % (response, )) class TestCheckAuthResponse(TestIdRes, CatchLogs): @@ -605,8 +628,8 @@ class TestCheckAuthResponse(TestIdRes, CatchLogs): def _createAssoc(self): issued = time.time() lifetime = 1000 - assoc = association.Association( - 'handle', 'secret', issued, lifetime, 'HMAC-SHA1') + assoc = association.Association('handle', 'secret', issued, lifetime, + 'HMAC-SHA1') store = self.consumer.store store.storeAssociation(self.server_url, assoc) assoc2 = store.getAssociation(self.server_url) @@ -643,7 +666,7 @@ class TestCheckAuthResponse(TestIdRes, CatchLogs): response = Message.fromOpenIDArgs({ 'is_valid': 'false', 'invalidate_handle': 'handle', - }) + }) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.assertFalse(r) self.assertTrue( @@ -654,19 +677,17 @@ class TestCheckAuthResponse(TestIdRes, CatchLogs): response = Message.fromOpenIDArgs({ 'is_valid': 'true', 'invalidate_handle': 'missing', - }) + }) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.assertTrue(r) - self.failUnlessLogMatches( - 'Received "invalidate_handle"' - ) + self.failUnlessLogMatches('Received "invalidate_handle"') def test_invalidateMissing_noStore(self): """invalidate_handle with a handle that is not present""" response = Message.fromOpenIDArgs({ 'is_valid': 'true', 'invalidate_handle': 'missing', - }) + }) self.consumer.store = None r = self.consumer._processCheckAuthResponse(response, self.server_url) self.assertTrue(r) @@ -687,7 +708,7 @@ class TestCheckAuthResponse(TestIdRes, CatchLogs): response = Message.fromOpenIDArgs({ 'is_valid': 'true', 'invalidate_handle': 'handle', - }) + }) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.assertTrue(r) self.assertTrue( @@ -709,7 +730,7 @@ class TestSetupNeeded(TestIdRes): message = Message.fromPostArgs({ 'openid.mode': 'id_res', 'openid.user_setup_url': setup_url, - }) + }) self.assertTrue(message.isOpenID1()) self.failUnlessSetupNeeded(setup_url, message) @@ -720,7 +741,7 @@ class TestSetupNeeded(TestIdRes): 'openid.mode': 'id_res', 'openid.user_setup_url': setup_url, 'openid.identity': 'bogus', - }) + }) self.assertTrue(message.isOpenID1()) self.failUnlessSetupNeeded(setup_url, message) @@ -737,7 +758,7 @@ class TestSetupNeeded(TestIdRes): message = Message.fromOpenIDArgs({ 'mode': 'setup_needed', 'ns': OPENID2_NS, - }) + }) self.assertTrue(message.isOpenID2()) response = self.consumer.complete(message, None, None) self.assertEqual('setup_needed', response.status) @@ -746,7 +767,7 @@ class TestSetupNeeded(TestIdRes): def test_setupNeededDoesntWorkForOpenID1(self): message = Message.fromOpenIDArgs({ 'mode': 'setup_needed', - }) + }) # No SetupNeededError raised self.consumer._checkSetupNeeded(message) @@ -760,7 +781,7 @@ class TestSetupNeeded(TestIdRes): 'mode': 'id_res', 'game': 'puerto_rico', 'ns': OPENID2_NS, - }) + }) self.assertTrue(message.isOpenID2()) # No SetupNeededError raised @@ -776,38 +797,46 @@ class IdResCheckForFieldsTest(TestIdRes): message = Message.fromOpenIDArgs(openid_args) message.setArg(OPENID_NS, 'signed', ','.join(signed_list)) self.consumer._idResCheckForFields(message) + return test - test_openid1Success = mkSuccessTest( - {'return_to': 'return', - 'assoc_handle': 'assoc handle', - 'sig': 'a signature', - 'identity': 'someone', - }, - ['return_to', 'identity']) - - test_openid2Success = mkSuccessTest( - {'ns': OPENID2_NS, - 'return_to': 'return', - 'assoc_handle': 'assoc handle', - 'sig': 'a signature', - 'op_endpoint': 'my favourite server', - 'response_nonce': 'use only once', - }, - ['return_to', 'response_nonce', 'assoc_handle', 'op_endpoint']) - - test_openid2Success_identifiers = mkSuccessTest( - {'ns': OPENID2_NS, - 'return_to': 'return', - 'assoc_handle': 'assoc handle', - 'sig': 'a signature', - 'claimed_id': 'i claim to be me', - 'identity': 'my server knows me as me', - 'op_endpoint': 'my favourite server', - 'response_nonce': 'use only once', - }, - ['return_to', 'response_nonce', 'identity', - 'claimed_id', 'assoc_handle', 'op_endpoint']) + test_openid1Success = mkSuccessTest({ + 'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'identity': 'someone', + }, ['return_to', 'identity']) + + test_openid2Success = mkSuccessTest({ + 'ns': OPENID2_NS, + 'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'op_endpoint': 'my favourite server', + 'response_nonce': 'use only once', + }, ['return_to', 'response_nonce', 'assoc_handle', 'op_endpoint']) + + test_openid2Success_identifiers = mkSuccessTest({ + 'ns': + OPENID2_NS, + 'return_to': + 'return', + 'assoc_handle': + 'assoc handle', + 'sig': + 'a signature', + 'claimed_id': + 'i claim to be me', + 'identity': + 'my server knows me as me', + 'op_endpoint': + 'my favourite server', + 'response_nonce': + 'use only once', + }, [ + 'return_to', 'response_nonce', 'identity', 'claimed_id', + 'assoc_handle', 'op_endpoint' + ]) def mkMissingFieldTest(openid_args): def test(self): @@ -818,6 +847,7 @@ class IdResCheckForFieldsTest(TestIdRes): self.assertTrue(str(why).startswith('Missing required')) else: self.fail('Expected an error, but none occurred') + return test def mkMissingSignedTest(openid_args): @@ -829,45 +859,66 @@ class IdResCheckForFieldsTest(TestIdRes): self.assertTrue(str(why).endswith('not signed')) else: self.fail('Expected an error, but none occurred') + return test - test_openid1Missing_returnToSig = mkMissingSignedTest( - {'return_to': 'return', - 'assoc_handle': 'assoc handle', - 'sig': 'a signature', - 'identity': 'someone', - 'signed': 'identity', - }) - - test_openid1Missing_identitySig = mkMissingSignedTest( - {'return_to': 'return', - 'assoc_handle': 'assoc handle', - 'sig': 'a signature', - 'identity': 'someone', - 'signed': 'return_to' - }) - - test_openid2Missing_opEndpointSig = mkMissingSignedTest( - {'ns': OPENID2_NS, - 'return_to': 'return', - 'assoc_handle': 'assoc handle', - 'sig': 'a signature', - 'identity': 'someone', - 'op_endpoint': 'the endpoint', - 'signed': 'return_to,identity,assoc_handle' - }) - - test_openid1MissingReturnTo = mkMissingFieldTest( - {'assoc_handle': 'assoc handle', - 'sig': 'a signature', - 'identity': 'someone', - }) - - test_openid1MissingAssocHandle = mkMissingFieldTest( - {'return_to': 'return', - 'sig': 'a signature', - 'identity': 'someone', - }) + test_openid1Missing_returnToSig = mkMissingSignedTest({ + 'return_to': + 'return', + 'assoc_handle': + 'assoc handle', + 'sig': + 'a signature', + 'identity': + 'someone', + 'signed': + 'identity', + }) + + test_openid1Missing_identitySig = mkMissingSignedTest({ + 'return_to': + 'return', + 'assoc_handle': + 'assoc handle', + 'sig': + 'a signature', + 'identity': + 'someone', + 'signed': + 'return_to' + }) + + test_openid2Missing_opEndpointSig = mkMissingSignedTest({ + 'ns': + OPENID2_NS, + 'return_to': + 'return', + 'assoc_handle': + 'assoc handle', + 'sig': + 'a signature', + 'identity': + 'someone', + 'op_endpoint': + 'the endpoint', + 'signed': + 'return_to,identity,assoc_handle' + }) + + test_openid1MissingReturnTo = mkMissingFieldTest({ + 'assoc_handle': + 'assoc handle', + 'sig': + 'a signature', + 'identity': + 'someone', + }) + + test_openid1MissingAssocHandle = mkMissingFieldTest({ + 'return_to': 'return', + 'sig': 'a signature', + 'identity': 'someone', + }) # XXX: I could go on... @@ -888,7 +939,7 @@ class CheckNonceVerifyTest(TestIdRes, CatchLogs): def test_openid1Success(self): """use consumer-generated nonce""" nonce_value = mkNonce() - self.return_to = 'http://rt.unittest/?nonce=%s' % (nonce_value,) + self.return_to = 'http://rt.unittest/?nonce=%s' % (nonce_value, ) self.response = Message.fromOpenIDArgs({'return_to': self.return_to}) self.response.setArg(BARE_NS, 'nonce', nonce_value) self.consumer._idResCheckNonce(self.response, self.endpoint) @@ -903,28 +954,36 @@ class CheckNonceVerifyTest(TestIdRes, CatchLogs): def test_consumerNonceOpenID2(self): """OpenID 2 does not use consumer-generated nonce""" - self.return_to = 'http://rt.unittest/?nonce=%s' % (mkNonce(),) - self.response = Message.fromOpenIDArgs( - {'return_to': self.return_to, 'ns': OPENID2_NS}) + self.return_to = 'http://rt.unittest/?nonce=%s' % (mkNonce(), ) + self.response = Message.fromOpenIDArgs({ + 'return_to': self.return_to, + 'ns': OPENID2_NS + }) self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, - self.response, self.endpoint) + self.response, self.endpoint) self.failUnlessLogEmpty() def test_serverNonce(self): """use server-generated nonce""" - self.response = Message.fromOpenIDArgs( - {'ns': OPENID2_NS, 'response_nonce': mkNonce()}) + self.response = Message.fromOpenIDArgs({ + 'ns': OPENID2_NS, + 'response_nonce': mkNonce() + }) self.consumer._idResCheckNonce(self.response, self.endpoint) self.failUnlessLogEmpty() def test_serverNonceOpenID1(self): """OpenID 1 does not use server-generated nonce""" - self.response = Message.fromOpenIDArgs( - {'ns': OPENID1_NS, - 'return_to': 'http://return.to/', - 'response_nonce': mkNonce()}) + self.response = Message.fromOpenIDArgs({ + 'ns': + OPENID1_NS, + 'return_to': + 'http://return.to/', + 'response_nonce': + mkNonce() + }) self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, - self.response, self.endpoint) + self.response, self.endpoint) self.failUnlessLogEmpty() def test_badNonce(self): @@ -941,37 +1000,37 @@ class CheckNonceVerifyTest(TestIdRes, CatchLogs): nonce = mkNonce() stamp, salt = splitNonce(nonce) self.store.useNonce(self.server_url, stamp, salt) - self.response = Message.fromOpenIDArgs( - {'response_nonce': nonce, - 'ns': OPENID2_NS, - }) + self.response = Message.fromOpenIDArgs({ + 'response_nonce': nonce, + 'ns': OPENID2_NS, + }) self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, - self.response, self.endpoint) + self.response, self.endpoint) def test_successWithNoStore(self): """When there is no store, checking the nonce succeeds""" self.consumer.store = None - self.response = Message.fromOpenIDArgs( - {'response_nonce': mkNonce(), - 'ns': OPENID2_NS, - }) + self.response = Message.fromOpenIDArgs({ + 'response_nonce': mkNonce(), + 'ns': OPENID2_NS, + }) self.consumer._idResCheckNonce(self.response, self.endpoint) self.failUnlessLogEmpty() def test_tamperedNonce(self): """Malformed nonce""" - self.response = Message.fromOpenIDArgs( - {'ns': OPENID2_NS, - 'response_nonce': 'malformed'}) + self.response = Message.fromOpenIDArgs({ + 'ns': OPENID2_NS, + 'response_nonce': 'malformed' + }) self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, - self.response, self.endpoint) + self.response, self.endpoint) def test_missingNonce(self): """no nonce parameter on the return_to""" - self.response = Message.fromOpenIDArgs( - {'return_to': self.return_to}) + self.response = Message.fromOpenIDArgs({'return_to': self.return_to}) self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, - self.response, self.endpoint) + self.response, self.endpoint) class CheckAuthDetectingConsumer(GenericConsumer): @@ -999,7 +1058,7 @@ class TestCheckAuthTriggered(TestIdRes, CatchLogs): 'openid.assoc_handle': 'not_found', 'openid.sig': GOODSIG, 'openid.signed': 'identity,return_to', - }) + }) self.disableReturnToChecking() try: result = self.consumer._doIdRes(message, self.endpoint, None) @@ -1014,8 +1073,8 @@ class TestCheckAuthTriggered(TestIdRes, CatchLogs): # handle that is in the message issued = time.time() lifetime = 1000 - assoc = association.Association( - 'handle', 'secret', issued, lifetime, 'HMAC-SHA1') + assoc = association.Association('handle', 'secret', issued, lifetime, + 'HMAC-SHA1') self.store.storeAssociation(self.server_url, assoc) self.disableReturnToChecking() message = Message.fromPostArgs({ @@ -1024,13 +1083,13 @@ class TestCheckAuthTriggered(TestIdRes, CatchLogs): 'openid.assoc_handle': 'not_found', 'openid.sig': GOODSIG, 'openid.signed': 'identity,return_to', - }) + }) try: result = self.consumer._doIdRes(message, self.endpoint, None) except CheckAuthHappened: pass else: - self.fail('_checkAuth did not happen. Result was: %r' % (result,)) + self.fail('_checkAuth did not happen. Result was: %r' % (result, )) def test_expiredAssoc(self): # Store an expired association for the server with the handle @@ -1038,8 +1097,8 @@ class TestCheckAuthTriggered(TestIdRes, CatchLogs): issued = time.time() - 10 lifetime = 0 handle = 'handle' - assoc = association.Association( - handle, 'secret', issued, lifetime, 'HMAC-SHA1') + assoc = association.Association(handle, 'secret', issued, lifetime, + 'HMAC-SHA1') self.assertTrue(assoc.expiresIn <= 0) self.store.storeAssociation(self.server_url, assoc) @@ -1049,10 +1108,10 @@ class TestCheckAuthTriggered(TestIdRes, CatchLogs): 'openid.assoc_handle': handle, 'openid.sig': GOODSIG, 'openid.signed': 'identity,return_to', - }) + }) self.disableReturnToChecking() - self.assertRaises(ProtocolError, self.consumer._doIdRes, - message, self.endpoint, None) + self.assertRaises(ProtocolError, self.consumer._doIdRes, message, + self.endpoint, None) def test_newerAssoc(self): lifetime = 1000 @@ -1065,15 +1124,15 @@ class TestCheckAuthTriggered(TestIdRes, CatchLogs): bad_issued = time.time() - 5 bad_handle = 'handle2' - bad_assoc = association.Association( - bad_handle, 'secret', bad_issued, lifetime, 'HMAC-SHA1') + bad_assoc = association.Association(bad_handle, 'secret', bad_issued, + lifetime, 'HMAC-SHA1') self.store.storeAssociation(self.server_url, bad_assoc) query = { 'return_to': self.return_to, 'identity': self.server_id, 'assoc_handle': good_handle, - } + } message = Message.fromOpenIDArgs(query) message = good_assoc.signMessage(message) @@ -1110,7 +1169,7 @@ class TestReturnToArgs(unittest.TestCase): 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/?foo=bar', 'foo': 'bar', - } + } # no return value, success is assumed if there are no exceptions. self.consumer._verifyReturnToArgs(query) @@ -1119,29 +1178,26 @@ class TestReturnToArgs(unittest.TestCase): 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/', 'foo': 'bar', - } + } # no return value, success is assumed if there are no exceptions. - self.assertRaises(ProtocolError, - self.consumer._verifyReturnToArgs, query) + self.assertRaises(ProtocolError, self.consumer._verifyReturnToArgs, + query) def test_returnToMismatch(self): query = { 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/?foo=bar', - } + } # fail, query has no key 'foo'. - self.assertRaises(ValueError, - self.consumer._verifyReturnToArgs, query) + self.assertRaises(ValueError, self.consumer._verifyReturnToArgs, query) query['foo'] = 'baz' # fail, values for 'foo' do not match. - self.assertRaises(ValueError, - self.consumer._verifyReturnToArgs, query) + self.assertRaises(ValueError, self.consumer._verifyReturnToArgs, query) def test_noReturnTo(self): query = {'openid.mode': 'id_res'} - self.assertRaises(ValueError, - self.consumer._verifyReturnToArgs, query) + self.assertRaises(ValueError, self.consumer._verifyReturnToArgs, query) def test_completeBadReturnTo(self): """Test GenericConsumer.complete()'s handling of bad return_to @@ -1162,7 +1218,7 @@ class TestReturnToArgs(unittest.TestCase): # Query args differ "http://some.url/path?foo=bar2", "http://some.url/path?foo2=bar", - ] + ] m = Message(OPENID1_NS) m.setArg(OPENID_NS, 'mode', 'cancel') @@ -1181,15 +1237,18 @@ class TestReturnToArgs(unittest.TestCase): good_return_tos = [ (return_to, {}), - (return_to + "?another=arg", {(BARE_NS, 'another'): 'arg'}), - (return_to + "?another=arg#fragment", - {(BARE_NS, 'another'): 'arg'}), + (return_to + "?another=arg", { + (BARE_NS, 'another'): 'arg' + }), + (return_to + "?another=arg#fragment", { + (BARE_NS, 'another'): 'arg' + }), ("HTTP" + return_to[4:], {}), (return_to.replace('url', 'URL'), {}), ("http://some.url:80/path", {}), ("http://some.url/p%61th", {}), ("http://some.url/./path", {}), - ] + ] endpoint = None @@ -1231,7 +1290,7 @@ class BadArgCheckingConsumer(GenericConsumer): 'openid.mode': 'check_authentication', 'openid.signed': 'foo', 'openid.ns': OPENID1_NS - }, args + }, args return None @@ -1253,12 +1312,11 @@ class TestCheckAuth(unittest.TestCase, CatchLogs): fetchers.setDefaultFetcher(self._orig_fetcher, wrap_exceptions=False) def test_error(self): - self.fetcher.response = HTTPResponse( - "http://some_url", 404, {'Hea': 'der'}, 'blah:blah\n') - query = {'openid.signed': 'stuff', - 'openid.stuff': 'a value'} - r = self.consumer._checkAuth(Message.fromPostArgs(query), - http_server_url) + self.fetcher.response = HTTPResponse("http://some_url", 404, + {'Hea': 'der'}, 'blah:blah\n') + query = {'openid.signed': 'stuff', 'openid.stuff': 'a value'} + r = self.consumer._checkAuth( + Message.fromPostArgs(query), http_server_url) self.assertFalse(r) self.assertTrue(self.messages) @@ -1266,21 +1324,29 @@ class TestCheckAuth(unittest.TestCase, CatchLogs): query = { 'openid.signed': 'foo', 'closid.foo': 'something', - } + } consumer = BadArgCheckingConsumer(self.store) consumer._checkAuth(Message.fromPostArgs(query), 'does://not.matter') def test_signedList(self): query = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'sig': 'rabbits', - 'identity': '=example', - 'assoc_handle': 'munchkins', - 'ns.sreg': 'urn:sreg', - 'sreg.email': 'bogus@example.com', - 'signed': 'identity,mode,ns.sreg,sreg.email', - 'foo': 'bar', - }) + 'mode': + 'id_res', + 'sig': + 'rabbits', + 'identity': + '=example', + 'assoc_handle': + 'munchkins', + 'ns.sreg': + 'urn:sreg', + 'sreg.email': + 'bogus@example.com', + 'signed': + 'identity,mode,ns.sreg,sreg.email', + 'foo': + 'bar', + }) args = self.consumer._createCheckAuthRequest(query) self.assertTrue(args.isOpenID1()) for signed_arg in query.getArg(OPENID_NS, 'signed').split(','): @@ -1288,20 +1354,34 @@ class TestCheckAuth(unittest.TestCase, CatchLogs): def test_112(self): args = { - 'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', - 'openid.claimed_id': 'http://binkley.lan/user/test01', - 'openid.identity': 'http://test01.binkley.lan/', - 'openid.mode': 'id_res', - 'openid.ns': 'http://specs.openid.net/auth/2.0', - 'openid.ns.pape': 'http://specs.openid.net/extensions/pape/1.0', - 'openid.op_endpoint': 'http://binkley.lan/server', - 'openid.pape.auth_policies': 'none', - 'openid.pape.auth_time': '2008-01-28T20:42:36Z', - 'openid.pape.nist_auth_level': '0', - 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', - 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', - 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', - 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,ns.pape,pape.nist_auth_level,pape.auth_policies' + 'openid.assoc_handle': + 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', + 'openid.claimed_id': + 'http://binkley.lan/user/test01', + 'openid.identity': + 'http://test01.binkley.lan/', + 'openid.mode': + 'id_res', + 'openid.ns': + 'http://specs.openid.net/auth/2.0', + 'openid.ns.pape': + 'http://specs.openid.net/extensions/pape/1.0', + 'openid.op_endpoint': + 'http://binkley.lan/server', + 'openid.pape.auth_policies': + 'none', + 'openid.pape.auth_time': + '2008-01-28T20:42:36Z', + 'openid.pape.nist_auth_level': + '0', + 'openid.response_nonce': + '2008-01-28T21:07:04Z99Q=', + 'openid.return_to': + 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', + 'openid.sig': + 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', + 'openid.signed': + 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,ns.pape,pape.nist_auth_level,pape.auth_policies' } self.assertEqual(OPENID2_NS, args['openid.ns']) incoming = Message.fromPostArgs(args) @@ -1330,13 +1410,13 @@ class TestFetchAssoc(unittest.TestCase, CatchLogs): def test_error_404(self): """404 from a kv post raises HTTPFetchingError""" - self.fetcher.response = HTTPResponse( - "http://some_url", 404, {'Hea': 'der'}, 'blah:blah\n') - self.assertRaises( - fetchers.HTTPFetchingError, - self.consumer._makeKVPost, - Message.fromPostArgs({'mode': 'associate'}), - "http://server_url") + self.fetcher.response = HTTPResponse("http://some_url", 404, + {'Hea': 'der'}, 'blah:blah\n') + self.assertRaises(fetchers.HTTPFetchingError, + self.consumer._makeKVPost, + Message.fromPostArgs({ + 'mode': 'associate' + }), "http://server_url") def test_error_exception_unwrapped(self): """Ensure that exceptions are bubbled through from fetchers @@ -1344,21 +1424,21 @@ class TestFetchAssoc(unittest.TestCase, CatchLogs): """ self.fetcher = ExceptionRaisingMockFetcher() fetchers.setDefaultFetcher(self.fetcher, wrap_exceptions=False) - self.assertRaises(self.fetcher.MyException, - self.consumer._makeKVPost, - Message.fromPostArgs({'mode': 'associate'}), - "http://server_url") + self.assertRaises(self.fetcher.MyException, self.consumer._makeKVPost, + Message.fromPostArgs({ + 'mode': 'associate' + }), "http://server_url") # exception fetching returns no association e = OpenIDServiceEndpoint() e.server_url = 'some://url' self.assertRaises(self.fetcher.MyException, - self.consumer._getAssociation, e) + self.consumer._getAssociation, e) - self.assertRaises(self.fetcher.MyException, - self.consumer._checkAuth, - Message.fromPostArgs({'openid.signed': ''}), - 'some://url') + self.assertRaises(self.fetcher.MyException, self.consumer._checkAuth, + Message.fromPostArgs({ + 'openid.signed': '' + }), 'some://url') def test_error_exception_wrapped(self): """Ensure that openid.fetchers.HTTPFetchingError is caught by @@ -1368,9 +1448,10 @@ class TestFetchAssoc(unittest.TestCase, CatchLogs): # This will wrap exceptions! fetchers.setDefaultFetcher(self.fetcher) self.assertRaises(fetchers.HTTPFetchingError, - self.consumer._makeKVPost, - Message.fromOpenIDArgs({'mode': 'associate'}), - "http://server_url") + self.consumer._makeKVPost, + Message.fromOpenIDArgs({ + 'mode': 'associate' + }), "http://server_url") # exception fetching returns no association e = OpenIDServiceEndpoint() @@ -1394,7 +1475,7 @@ class TestSuccessResponse(unittest.TestCase): 'unittest.two': '2', 'sreg.nickname': 'j3h', 'return_to': 'return_to', - }) + }) utargs = resp.extensionResponse('urn:unittest', False) self.assertEqual(utargs, {'one': '1', 'two': '2'}) sregargs = resp.extensionResponse('urn:sreg', False) @@ -1410,11 +1491,11 @@ class TestSuccessResponse(unittest.TestCase): 'sreg.dob': 'yesterday', 'return_to': 'return_to', 'signed': 'sreg.nickname,unittest.one,sreg.dob', - } + } - signed_list = ['openid.sreg.nickname', - 'openid.unittest.one', - 'openid.sreg.dob'] + signed_list = [ + 'openid.sreg.nickname', 'openid.unittest.one', 'openid.sreg.dob' + ] # Don't use mkSuccess because it creates an all-inclusive # signed list. @@ -1423,10 +1504,7 @@ class TestSuccessResponse(unittest.TestCase): # All args in this NS are signed, so expect all. sregargs = resp.extensionResponse('urn:sreg', True) - self.assertEqual(sregargs, { - 'nickname': 'j3h', - 'dob': 'yesterday' - }) + self.assertEqual(sregargs, {'nickname': 'j3h', 'dob': 'yesterday'}) # Not all args in this NS are signed, so expect None when # asking for them. @@ -1443,14 +1521,12 @@ class TestSuccessResponse(unittest.TestCase): def test_displayIdentifierClaimedId(self): resp = mkSuccess(self.endpoint, {}) - self.assertEqual(resp.getDisplayIdentifier(), - resp.endpoint.claimed_id) + self.assertEqual(resp.getDisplayIdentifier(), resp.endpoint.claimed_id) def test_displayIdentifierOverride(self): self.endpoint.display_identifier = "http://input.url/" resp = mkSuccess(self.endpoint, {}) - self.assertEqual(resp.getDisplayIdentifier(), - "http://input.url/") + self.assertEqual(resp.getDisplayIdentifier(), "http://input.url/") class StubConsumer(object): @@ -1474,6 +1550,7 @@ class ConsumerTest(unittest.TestCase): Its GenericConsumer component is stubbed out with StubConsumer. """ + def setUp(self): self.endpoint = OpenIDServiceEndpoint() self.endpoint.claimed_id = self.identity_url = 'http://identity.url/' @@ -1481,19 +1558,18 @@ class ConsumerTest(unittest.TestCase): self.session = {} self.consumer = Consumer(self.session, self.store) self.consumer.consumer = StubConsumer() - self.discovery = Discovery(self.session, - self.identity_url, + self.discovery = Discovery(self.session, self.identity_url, self.consumer.session_key_prefix) def test_setAssociationPreference(self): self.consumer.setAssociationPreference([]) - self.assertTrue(isinstance(self.consumer.consumer.negotiator, - association.SessionNegotiator)) - self.assertEqual([], - self.consumer.consumer.negotiator.allowed_types) + self.assertTrue( + isinstance(self.consumer.consumer.negotiator, + association.SessionNegotiator)) + self.assertEqual([], self.consumer.consumer.negotiator.allowed_types) self.consumer.setAssociationPreference([('HMAC-SHA1', 'DH-SHA1')]) self.assertEqual([('HMAC-SHA1', 'DH-SHA1')], - self.consumer.consumer.negotiator.allowed_types) + self.consumer.consumer.negotiator.allowed_types) def withDummyDiscovery(self, callable, dummy_getNextService): class DummyDisco(object): @@ -1513,6 +1589,7 @@ class ConsumerTest(unittest.TestCase): def test_beginHTTPError(self): """Make sure that the discovery HTTP failure case behaves properly """ + def getNextService(self, ignored): raise HTTPFetchingError("Unit test") @@ -1652,9 +1729,8 @@ class ConsumerTest(unittest.TestCase): def test_completeSetupNeeded(self): setup_url = 'http://setup.url/' - resp = self._doRespDisco( - False, - SetupNeededResponse(self.endpoint, setup_url)) + resp = self._doRespDisco(False, + SetupNeededResponse(self.endpoint, setup_url)) self.assertTrue(resp.setup_url is setup_url) def test_successDifferentURL(self): @@ -1699,12 +1775,17 @@ class IDPDrivenTest(unittest.TestCase): def test_idpDrivenComplete(self): identifier = '=directed_identifier' message = Message.fromPostArgs({ - 'openid.identity': '=directed_identifier', - 'openid.return_to': 'x', - 'openid.assoc_handle': 'z', - 'openid.signed': 'identity,return_to', - 'openid.sig': GOODSIG, - }) + 'openid.identity': + '=directed_identifier', + 'openid.return_to': + 'x', + 'openid.assoc_handle': + 'z', + 'openid.signed': + 'identity,return_to', + 'openid.sig': + GOODSIG, + }) discovered_endpoint = OpenIDServiceEndpoint() discovered_endpoint.claimed_id = identifier @@ -1731,24 +1812,29 @@ class IDPDrivenTest(unittest.TestCase): def test_idpDrivenCompleteFraud(self): # crap with an identifier that doesn't match discovery info message = Message.fromPostArgs({ - 'openid.identity': '=directed_identifier', - 'openid.return_to': 'x', - 'openid.assoc_handle': 'z', - 'openid.signed': 'identity,return_to', - 'openid.sig': GOODSIG, - }) + 'openid.identity': + '=directed_identifier', + 'openid.return_to': + 'x', + 'openid.assoc_handle': + 'z', + 'openid.signed': + 'identity,return_to', + 'openid.sig': + GOODSIG, + }) def verifyDiscoveryResults(identifier, endpoint): raise DiscoveryFailure("PHREAK!", None) self.consumer._verifyDiscoveryResults = verifyDiscoveryResults self.consumer._checkReturnTo = lambda unused1, unused2: True - self.assertRaises(DiscoveryFailure, self.consumer._doIdRes, - message, self.endpoint, None) + self.assertRaises(DiscoveryFailure, self.consumer._doIdRes, message, + self.endpoint, None) def failUnlessSuccess(self, response): if response.status != SUCCESS: - self.fail("Non-successful response: %s" % (response,)) + self.fail("Non-successful response: %s" % (response, )) class TestDiscoveryVerification(unittest.TestCase): @@ -1764,11 +1850,15 @@ class TestDiscoveryVerification(unittest.TestCase): self.server_url = "http://endpoint.unittest/" self.message = Message.fromPostArgs({ - 'openid.ns': OPENID2_NS, - 'openid.identity': self.identifier, - 'openid.claimed_id': self.identifier, - 'openid.op_endpoint': self.server_url, - }) + 'openid.ns': + OPENID2_NS, + 'openid.identity': + self.identifier, + 'openid.claimed_id': + self.identifier, + 'openid.op_endpoint': + self.server_url, + }) self.endpoint = OpenIDServiceEndpoint() self.endpoint.server_url = self.server_url @@ -1808,7 +1898,7 @@ class TestDiscoveryVerification(unittest.TestCase): # Should we make more ProtocolError subclasses? self.assertTrue(str(e), text) else: - self.fail("expected ProtocolError, %r returned." % (r,)) + self.fail("expected ProtocolError, %r returned." % (r, )) def test_foreignDelegate(self): text = "verify failed" @@ -1833,14 +1923,14 @@ class TestDiscoveryVerification(unittest.TestCase): except ProtocolError as e: self.assertEqual(str(e), text) else: - self.fail("Exepected ProtocolError, %r returned" % (r,)) + self.fail("Exepected ProtocolError, %r returned" % (r, )) def test_nothingDiscovered(self): # a set of no things. self.services = [] self.assertRaises(DiscoveryFailure, - self.consumer._verifyDiscoveryResults, - self.message, self.endpoint) + self.consumer._verifyDiscoveryResults, self.message, + self.endpoint) def discoveryFunc(self, identifier): return identifier, self.services @@ -1864,12 +1954,12 @@ class TestCreateAssociationRequest(unittest.TestCase): self.endpoint, self.assoc_type, session_type) self.assertTrue(isinstance(session, PlainTextConsumerSession)) - expected = Message.fromOpenIDArgs( - {'ns': OPENID2_NS, - 'session_type': session_type, - 'mode': 'associate', - 'assoc_type': self.assoc_type, - }) + expected = Message.fromOpenIDArgs({ + 'ns': OPENID2_NS, + 'session_type': session_type, + 'mode': 'associate', + 'assoc_type': self.assoc_type, + }) self.assertEqual(expected, args) @@ -1880,10 +1970,11 @@ class TestCreateAssociationRequest(unittest.TestCase): self.endpoint, self.assoc_type, session_type) self.assertTrue(isinstance(session, PlainTextConsumerSession)) - self.assertEqual(Message.fromOpenIDArgs({ - 'mode': 'associate', - 'assoc_type': self.assoc_type, - }), args) + self.assertEqual( + Message.fromOpenIDArgs({ + 'mode': 'associate', + 'assoc_type': self.assoc_type, + }), args) def test_dhSHA1Compatibility(self): # Set the consumer's session type to a fast session since we @@ -1934,8 +2025,7 @@ class _TestingDiffieHellmanResponseParameters(object): self.secret = cryptutil.randomString(self.session_cls.secret_size) self.enc_mac_key = oidutil.toBase64( - self.server_dh.xorSecret(self.consumer_dh.public, - self.secret, + self.server_dh.xorSecret(self.consumer_dh.public, self.secret, self.session_cls.hash_func)) self.consumer_session = self.session_cls(self.consumer_dh) @@ -1953,29 +2043,27 @@ class _TestingDiffieHellmanResponseParameters(object): self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key) self.assertRaises(KeyError, self.consumer_session.extractSecret, - self.msg) + self.msg) def testAbsentMacKey(self): self.msg.setArg(OPENID_NS, 'dh_server_public', self.dh_server_public) self.assertRaises(KeyError, self.consumer_session.extractSecret, - self.msg) + self.msg) def testInvalidBase64Public(self): self.msg.setArg(OPENID_NS, 'dh_server_public', 'n o t b a s e 6 4.') self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key) - self.assertRaises(ValueError, - self.consumer_session.extractSecret, - self.msg) + self.assertRaises(ValueError, self.consumer_session.extractSecret, + self.msg) def testInvalidBase64MacKey(self): self.msg.setArg(OPENID_NS, 'dh_server_public', self.dh_server_public) self.msg.setArg(OPENID_NS, 'enc_mac_key', 'n o t base 64') - self.assertRaises(ValueError, - self.consumer_session.extractSecret, - self.msg) + self.assertRaises(ValueError, self.consumer_session.extractSecret, + self.msg) class TestOpenID1SHA1(_TestingDiffieHellmanResponseParameters, @@ -1989,6 +2077,7 @@ class TestOpenID2SHA1(_TestingDiffieHellmanResponseParameters, session_cls = DiffieHellmanSHA1ConsumerSession message_namespace = OPENID2_NS + if cryptutil.SHA256_AVAILABLE: class TestOpenID2SHA256(_TestingDiffieHellmanResponseParameters, @@ -2005,6 +2094,7 @@ class TestNoStore(unittest.TestCase): def test_completeNoGetAssoc(self): """_getAssociation is never called when the store is None""" + def notCalled(unused): self.fail('This method was unexpectedly called') @@ -2033,10 +2123,9 @@ class TestConsumerAnonymous(unittest.TestCase): def bogusBegin(unused): return NonAnonymousAuthRequest() + consumer.consumer.begin = bogusBegin - self.assertRaises( - ProtocolError, - consumer.beginWithoutDiscovery, None) + self.assertRaises(ProtocolError, consumer.beginWithoutDiscovery, None) class TestDiscoverAndVerify(unittest.TestCase): @@ -2046,15 +2135,13 @@ class TestDiscoverAndVerify(unittest.TestCase): def dummyDiscover(unused_identifier): return self.discovery_result + self.consumer._discover = dummyDiscover self.to_match = OpenIDServiceEndpoint() def failUnlessDiscoveryFailure(self): - self.assertRaises( - DiscoveryFailure, - self.consumer._discoverAndVerify, - 'http://claimed-id.com/', - [self.to_match]) + self.assertRaises(DiscoveryFailure, self.consumer._discoverAndVerify, + 'http://claimed-id.com/', [self.to_match]) def test_noServices(self): """Discovery returning no results results in a @@ -2070,6 +2157,7 @@ class TestDiscoverAndVerify(unittest.TestCase): def raiseProtocolError(unused1, unused2): raise ProtocolError('unit test') + self.consumer._verifyDiscoverySingle = raiseProtocolError self.failUnlessDiscoveryFailure() @@ -2083,12 +2171,13 @@ class TestDiscoverAndVerify(unittest.TestCase): # Make verifying discovery return True for this endpoint def returnTrue(unused1, unused2): return True + self.consumer._verifyDiscoverySingle = returnTrue # Since _verifyDiscoverySingle returns True, we should get the # first endpoint that we passed in as a result. - result = self.consumer._discoverAndVerify( - 'http://claimed.id/', [self.to_match]) + result = self.consumer._discoverAndVerify('http://claimed.id/', + [self.to_match]) self.assertEqual(matching_endpoint, result) @@ -2114,7 +2203,7 @@ class TestAddExtension(unittest.TestCase): class TestKVPost(unittest.TestCase): def setUp(self): - self.server_url = 'http://unittest/%s' % (self.id(),) + self.server_url = 'http://unittest/%s' % (self.id(), ) def test_200(self): from openid.fetchers import HTTPResponse @@ -2135,16 +2224,15 @@ class TestKVPost(unittest.TestCase): self.assertEqual(e.error_text, 'bonk') self.assertEqual(e.error_code, '7') else: - self.fail("Expected ServerError, got return %r" % (r,)) + self.fail("Expected ServerError, got return %r" % (r, )) def test_500(self): # 500 as an example of any non-200, non-400 code. response = HTTPResponse() response.status = 500 response.body = "foo:bar\nbaz:quux\n" - self.assertRaises(fetchers.HTTPFetchingError, - _httpResponseToMessage, response, - self.server_url) + self.assertRaises(fetchers.HTTPFetchingError, _httpResponseToMessage, + response, self.server_url) if __name__ == '__main__': diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index bf109a1dbd45743f60e9f803f64864529f69df6a..558e3d0c939b1bd6e73225325626594188e77706 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -14,6 +14,7 @@ from openid.yadis import xrires from openid.yadis.xri import XRI from openid import message import openid.store.memstore + ### Tests for conditions that trigger DiscoveryFailure @@ -34,10 +35,14 @@ class TestDiscoveryFailure(datadriven.DataDrivenTestCase): [HTTPResponse('http://not.found/', 404)], [HTTPResponse('http://bad.request/', 400)], [HTTPResponse('http://server.error/', 500)], - [HTTPResponse('http://header.found/', 200, - headers={'x-xrds-location':'http://xrds.missing/'}), - HTTPResponse('http://xrds.missing/', 404)], - ] + [ + HTTPResponse( + 'http://header.found/', + 200, + headers={'x-xrds-location': 'http://xrds.missing/'}), + HTTPResponse('http://xrds.missing/', 404) + ], + ] def __init__(self, responses): self.url = responses[0].final_url @@ -96,7 +101,7 @@ class TestFetchException(datadriven.DataDrivenTestCase): DidFetch(), ValueError(), RuntimeError(), - ] + ] # String exceptions are finally gone from Python 2.6. if sys.version_info[:2] < (2, 6): @@ -129,6 +134,7 @@ class TestFetchException(datadriven.DataDrivenTestCase): ### Tests for openid.consumer.discover.discover + class TestNormalization(unittest.TestCase): def testAddingProtocol(self): f = ErrorRaisingFetcher(RuntimeError()) @@ -169,6 +175,7 @@ class DiscoveryMockFetcher(object): return HTTPResponse(final_url, status, {'content-type': ctype}, body) + # from twisted.trial import unittest as trialtest @@ -178,15 +185,15 @@ class BaseTestDiscovery(unittest.TestCase): documents = {} fetcherClass = DiscoveryMockFetcher - def _checkService(self, s, + def _checkService(self, + s, server_url, claimed_id=None, local_id=None, canonical_id=None, types=None, used_yadis=False, - display_identifier=None - ): + display_identifier=None): self.assertEqual(server_url, s.server_url) if types == ['2.0 OP']: self.assertFalse(claimed_id) @@ -197,7 +204,7 @@ class BaseTestDiscovery(unittest.TestCase): self.assertFalse(s.compatibilityMode()) self.assertTrue(s.isOPIdentifier()) self.assertEqual(s.preferredNamespace(), - discover.OPENID_2_0_MESSAGE_NS) + discover.OPENID_2_0_MESSAGE_NS) else: self.assertEqual(claimed_id, s.claimed_id) self.assertEqual(local_id, s.getLocalID()) @@ -206,14 +213,14 @@ class BaseTestDiscovery(unittest.TestCase): self.assertTrue(s.used_yadis, "Expected to use Yadis") else: self.assertFalse(s.used_yadis, - "Expected to use old-style discovery") + "Expected to use old-style discovery") openid_types = { '1.1': discover.OPENID_1_1_TYPE, '1.0': discover.OPENID_1_0_TYPE, '2.0': discover.OPENID_2_0_TYPE, '2.0 OP': discover.OPENID_IDP_2_0_TYPE, - } + } type_uris = [openid_types[t] for t in types] self.assertEqual(type_uris, s.type_uris) @@ -243,15 +250,18 @@ def readDataFile(filename): read in binary mode and the return value is a bytes object. """ module_directory = os.path.dirname(os.path.abspath(__file__)) - filename = os.path.join( - module_directory, 'data', 'test_discover', filename) + filename = os.path.join(module_directory, 'data', 'test_discover', + filename) with open(filename, 'rb') as f: contents = f.read() return contents class TestDiscovery(BaseTestDiscovery): - def _discover(self, content_type, data, expected_service_count, + def _discover(self, + content_type, + data, + expected_service_count, expected_id=None): if expected_id is None: expected_id = self.id_url @@ -263,8 +273,8 @@ class TestDiscovery(BaseTestDiscovery): return services def test_404(self): - self.assertRaises( - DiscoveryFailure, discover.discover, self.id_url + '/404') + self.assertRaises(DiscoveryFailure, discover.discover, + self.id_url + '/404') def test_unicode(self): """ @@ -280,24 +290,24 @@ class TestDiscovery(BaseTestDiscovery): Check page with unicode and HTML entities that can not be decoded but xrds document is found before it matters """ - self.documents[self.id_url + 'xrds'] = ( - 'application/xrds+xml', readDataFile('yadis_idp.xml')) + self.documents[self.id_url + 'xrds'] = ('application/xrds+xml', + readDataFile('yadis_idp.xml')) data = readDataFile('unicode3.html') self.assertRaises(UnicodeDecodeError, data.decode, 'utf-8') - self._discover(content_type='text/html;charset=utf-8', - data=data, expected_service_count=1) + self._discover( + content_type='text/html;charset=utf-8', + data=data, + expected_service_count=1) def test_noOpenID(self): - services = self._discover(content_type='text/plain', - data="junk", - expected_service_count=0) + services = self._discover( + content_type='text/plain', data="junk", expected_service_count=0) services = self._discover( content_type='text/html', data=readDataFile('openid_no_delegate.html'), - expected_service_count=1, - ) + expected_service_count=1, ) self._checkService( services[0], @@ -305,8 +315,7 @@ class TestDiscovery(BaseTestDiscovery): types=['1.1'], server_url="http://www.myopenid.com/server", claimed_id=self.id_url, - local_id=self.id_url, - ) + local_id=self.id_url, ) def test_html1(self): services = self._discover( @@ -321,8 +330,7 @@ class TestDiscovery(BaseTestDiscovery): server_url="http://www.myopenid.com/server", claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', - display_identifier=self.id_url, - ) + display_identifier=self.id_url, ) def test_html1Fragment(self): """Ensure that the Claimed Identifier does not have a fragment @@ -345,15 +353,13 @@ class TestDiscovery(BaseTestDiscovery): server_url="http://www.myopenid.com/server", claimed_id=expected_id, local_id='http://smoker.myopenid.com/', - display_identifier=expected_id, - ) + display_identifier=expected_id, ) def test_html2(self): services = self._discover( content_type='text/html', data=readDataFile('openid2.html'), - expected_service_count=1, - ) + expected_service_count=1, ) self._checkService( services[0], @@ -362,15 +368,13 @@ class TestDiscovery(BaseTestDiscovery): server_url="http://www.myopenid.com/server", claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', - display_identifier=self.id_url, - ) + display_identifier=self.id_url, ) def test_html1And2(self): services = self._discover( content_type='text/html', data=readDataFile('openid_1_and_2.html'), - expected_service_count=2, - ) + expected_service_count=2, ) for t, s in zip(['2.0', '1.1'], services): self._checkService( @@ -380,13 +384,13 @@ class TestDiscovery(BaseTestDiscovery): server_url="http://www.myopenid.com/server", claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', - display_identifier=self.id_url, - ) + display_identifier=self.id_url, ) def test_yadisEmpty(self): - services = self._discover(content_type='application/xrds+xml', - data=readDataFile('yadis_0entries.xml'), - expected_service_count=0) + services = self._discover( + content_type='application/xrds+xml', + data=readDataFile('yadis_0entries.xml'), + expected_service_count=0) def test_htmlEmptyYadis(self): """HTML document has discovery information, but points to an @@ -395,9 +399,10 @@ class TestDiscovery(BaseTestDiscovery): self.documents[self.id_url + 'xrds'] = ( 'application/xrds+xml', readDataFile('yadis_0entries.xml')) - services = self._discover(content_type='text/html', - data=readDataFile('openid_and_yadis.html'), - expected_service_count=1) + services = self._discover( + content_type='text/html', + data=readDataFile('openid_and_yadis.html'), + expected_service_count=1) self._checkService( services[0], @@ -406,13 +411,13 @@ class TestDiscovery(BaseTestDiscovery): server_url="http://www.myopenid.com/server", claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', - display_identifier=self.id_url, - ) + display_identifier=self.id_url, ) def test_yadis1NoDelegate(self): - services = self._discover(content_type='application/xrds+xml', - data=readDataFile('yadis_no_delegate.xml'), - expected_service_count=1) + services = self._discover( + content_type='application/xrds+xml', + data=readDataFile('yadis_no_delegate.xml'), + expected_service_count=1) self._checkService( services[0], @@ -421,15 +426,13 @@ class TestDiscovery(BaseTestDiscovery): server_url="http://www.myopenid.com/server", claimed_id=self.id_url, local_id=self.id_url, - display_identifier=self.id_url, - ) + display_identifier=self.id_url, ) def test_yadis2NoLocalID(self): services = self._discover( content_type='application/xrds+xml', data=readDataFile('openid2_xrds_no_local_id.xml'), - expected_service_count=1, - ) + expected_service_count=1, ) self._checkService( services[0], @@ -438,15 +441,13 @@ class TestDiscovery(BaseTestDiscovery): server_url="http://www.myopenid.com/server", claimed_id=self.id_url, local_id=self.id_url, - display_identifier=self.id_url, - ) + display_identifier=self.id_url, ) def test_yadis2(self): services = self._discover( content_type='application/xrds+xml', data=readDataFile('openid2_xrds.xml'), - expected_service_count=1, - ) + expected_service_count=1, ) self._checkService( services[0], @@ -455,53 +456,48 @@ class TestDiscovery(BaseTestDiscovery): server_url="http://www.myopenid.com/server", claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', - display_identifier=self.id_url, - ) + display_identifier=self.id_url, ) def test_yadis2OP(self): services = self._discover( content_type='application/xrds+xml', data=readDataFile('yadis_idp.xml'), - expected_service_count=1, - ) + expected_service_count=1, ) self._checkService( services[0], used_yadis=True, types=['2.0 OP'], server_url="http://www.myopenid.com/server", - display_identifier=self.id_url, - ) + display_identifier=self.id_url, ) def test_yadis2OPDelegate(self): """The delegate tag isn't meaningful for OP entries.""" services = self._discover( content_type='application/xrds+xml', data=readDataFile('yadis_idp_delegate.xml'), - expected_service_count=1, - ) + expected_service_count=1, ) self._checkService( services[0], used_yadis=True, types=['2.0 OP'], server_url="http://www.myopenid.com/server", - display_identifier=self.id_url, - ) + display_identifier=self.id_url, ) def test_yadis2BadLocalID(self): - self.assertRaises(DiscoveryFailure, self._discover, + self.assertRaises( + DiscoveryFailure, + self._discover, content_type='application/xrds+xml', data=readDataFile('yadis_2_bad_local_id.xml'), - expected_service_count=1, - ) + expected_service_count=1, ) def test_yadis1And2(self): services = self._discover( content_type='application/xrds+xml', data=readDataFile('openid_1_and_2_xrds.xml'), - expected_service_count=1, - ) + expected_service_count=1, ) self._checkService( services[0], @@ -510,19 +506,18 @@ class TestDiscovery(BaseTestDiscovery): server_url="http://www.myopenid.com/server", claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', - display_identifier=self.id_url, - ) + display_identifier=self.id_url, ) def test_yadis1And2BadLocalID(self): - self.assertRaises(DiscoveryFailure, self._discover, + self.assertRaises( + DiscoveryFailure, + self._discover, content_type='application/xrds+xml', data=readDataFile('openid_1_and_2_xrds_bad_delegate.xml'), - expected_service_count=1, - ) + expected_service_count=1, ) class MockFetcherForXRIProxy(object): - def __init__(self, documents, proxy_url=xrires.DEFAULT_PROXY): self.documents = documents self.fetchlog = [] @@ -558,10 +553,12 @@ class MockFetcherForXRIProxy(object): class TestXRIDiscovery(BaseTestDiscovery): fetcherClass = MockFetcherForXRIProxy - documents = {'=smoker': ('application/xrds+xml', - readDataFile('yadis_2entries_delegate.xml')), - '=smoker*bad': ('application/xrds+xml', - readDataFile('yadis_another_delegate.xml'))} + documents = { + '=smoker': ('application/xrds+xml', + readDataFile('yadis_2entries_delegate.xml')), + '=smoker*bad': ('application/xrds+xml', + readDataFile('yadis_another_delegate.xml')) + } def test_xri(self): user_xri, services = discover.discoverXRI('=smoker') @@ -574,8 +571,7 @@ class TestXRIDiscovery(BaseTestDiscovery): claimed_id=XRI("=!1000"), canonical_id=XRI("=!1000"), local_id='http://smoker.myopenid.com/', - display_identifier='=smoker' - ) + display_identifier='=smoker') self._checkService( services[1], @@ -585,8 +581,7 @@ class TestXRIDiscovery(BaseTestDiscovery): claimed_id=XRI("=!1000"), canonical_id=XRI("=!1000"), local_id='http://frank.livejournal.com/', - display_identifier='=smoker' - ) + display_identifier='=smoker') def test_xri_normalize(self): user_xri, services = discover.discoverXRI('xri://=smoker') @@ -599,8 +594,7 @@ class TestXRIDiscovery(BaseTestDiscovery): claimed_id=XRI("=!1000"), canonical_id=XRI("=!1000"), local_id='http://smoker.myopenid.com/', - display_identifier='=smoker' - ) + display_identifier='=smoker') self._checkService( services[1], @@ -610,8 +604,7 @@ class TestXRIDiscovery(BaseTestDiscovery): claimed_id=XRI("=!1000"), canonical_id=XRI("=!1000"), local_id='http://frank.livejournal.com/', - display_identifier='=smoker' - ) + display_identifier='=smoker') def test_xriNoCanonicalID(self): user_xri, services = discover.discoverXRI('=smoker*bad') @@ -629,20 +622,22 @@ class TestXRIDiscovery(BaseTestDiscovery): class TestXRIDiscoveryIDP(BaseTestDiscovery): fetcherClass = MockFetcherForXRIProxy - documents = {'=smoker': ('application/xrds+xml', - readDataFile('yadis_2entries_idp.xml'))} + documents = { + '=smoker': ('application/xrds+xml', + readDataFile('yadis_2entries_idp.xml')) + } def test_xri(self): user_xri, services = discover.discoverXRI('=smoker') self.assertTrue(services, "Expected services, got zero") self.assertEqual(services[0].server_url, - "http://www.livejournal.com/openid/server.bml") + "http://www.livejournal.com/openid/server.bml") class TestPreferredNamespace(datadriven.DataDrivenTestCase): def __init__(self, expected_ns, type_uris): - datadriven.DataDrivenTestCase.__init__( - self, 'Expecting %s from %s' % (expected_ns, type_uris)) + datadriven.DataDrivenTestCase.__init__(self, 'Expecting %s from %s' % + (expected_ns, type_uris)) self.expected_ns = expected_ns self.type_uris = type_uris @@ -659,11 +654,11 @@ class TestPreferredNamespace(datadriven.DataDrivenTestCase): (message.OPENID1_NS, [discover.OPENID_1_1_TYPE]), (message.OPENID2_NS, [discover.OPENID_2_0_TYPE]), (message.OPENID2_NS, [discover.OPENID_IDP_2_0_TYPE]), - (message.OPENID2_NS, [discover.OPENID_2_0_TYPE, - discover.OPENID_1_0_TYPE]), - (message.OPENID2_NS, [discover.OPENID_1_0_TYPE, - discover.OPENID_2_0_TYPE]), - ] + (message.OPENID2_NS, + [discover.OPENID_2_0_TYPE, discover.OPENID_1_0_TYPE]), + (message.OPENID2_NS, + [discover.OPENID_1_0_TYPE, discover.OPENID_2_0_TYPE]), + ] class TestIsOPIdentifier(unittest.TestCase): @@ -690,14 +685,16 @@ class TestIsOPIdentifier(unittest.TestCase): self.assertTrue(self.endpoint.isOPIdentifier()) def test_multipleMissing(self): - self.endpoint.type_uris = [discover.OPENID_2_0_TYPE, - discover.OPENID_1_0_TYPE] + self.endpoint.type_uris = [ + discover.OPENID_2_0_TYPE, discover.OPENID_1_0_TYPE + ] self.assertFalse(self.endpoint.isOPIdentifier()) def test_multiplePresent(self): - self.endpoint.type_uris = [discover.OPENID_2_0_TYPE, - discover.OPENID_1_0_TYPE, - discover.OPENID_IDP_2_0_TYPE] + self.endpoint.type_uris = [ + discover.OPENID_2_0_TYPE, discover.OPENID_1_0_TYPE, + discover.OPENID_IDP_2_0_TYPE + ] self.assertTrue(self.endpoint.isOPIdentifier()) @@ -761,18 +758,19 @@ class TestEndpointSupportsType(unittest.TestCase): def failUnlessSupportsOnly(self, *types): for t in [ - 'foo', - discover.OPENID_1_1_TYPE, - discover.OPENID_1_0_TYPE, - discover.OPENID_2_0_TYPE, - discover.OPENID_IDP_2_0_TYPE, - ]: + 'foo', + discover.OPENID_1_1_TYPE, + discover.OPENID_1_0_TYPE, + discover.OPENID_2_0_TYPE, + discover.OPENID_IDP_2_0_TYPE, + ]: if t in types: - self.assertTrue(self.endpoint.supportsType(t), - "Must support %r" % (t,)) + self.assertTrue( + self.endpoint.supportsType(t), "Must support %r" % (t, )) else: - self.assertFalse(self.endpoint.supportsType(t), - "Shouldn't support %r" % (t,)) + self.assertFalse( + self.endpoint.supportsType(t), + "Shouldn't support %r" % (t, )) def test_supportsNothing(self): self.failUnlessSupportsOnly() @@ -795,19 +793,21 @@ class TestEndpointSupportsType(unittest.TestCase): self.failUnlessSupportsOnly(discover.OPENID_1_1_TYPE) def test_multiple(self): - self.endpoint.type_uris = [discover.OPENID_1_1_TYPE, - discover.OPENID_2_0_TYPE] + self.endpoint.type_uris = [ + discover.OPENID_1_1_TYPE, discover.OPENID_2_0_TYPE + ] self.failUnlessSupportsOnly(discover.OPENID_1_1_TYPE, discover.OPENID_2_0_TYPE) def test_multipleWithProvider(self): - self.endpoint.type_uris = [discover.OPENID_1_1_TYPE, - discover.OPENID_2_0_TYPE, - discover.OPENID_IDP_2_0_TYPE] - self.failUnlessSupportsOnly(discover.OPENID_1_1_TYPE, - discover.OPENID_2_0_TYPE, - discover.OPENID_IDP_2_0_TYPE, - ) + self.endpoint.type_uris = [ + discover.OPENID_1_1_TYPE, discover.OPENID_2_0_TYPE, + discover.OPENID_IDP_2_0_TYPE + ] + self.failUnlessSupportsOnly( + discover.OPENID_1_1_TYPE, + discover.OPENID_2_0_TYPE, + discover.OPENID_IDP_2_0_TYPE, ) class TestEndpointDisplayIdentifier(unittest.TestCase): @@ -827,9 +827,8 @@ class TestDiscoveryFailureDjangoAllAuth(unittest.TestCase): # DiscoveryFailure with self.assertRaises(DiscoveryFailure): auth_request = client.begin("http://www.google.com") - result = auth_request.redirectURL( - 'http://localhost/', - 'http://localhost/callback') + result = auth_request.redirectURL('http://localhost/', + 'http://localhost/callback') self.assertEquals(result, None) diff --git a/openid/test/test_etxrd.py b/openid/test/test_etxrd.py index fbf7f6ee83b4aee3ed3ba2c9a44e260ee8f19054..bca81aa12350c71fd21795a734a3d11176bdd3f7 100644 --- a/openid/test/test_etxrd.py +++ b/openid/test/test_etxrd.py @@ -7,6 +7,7 @@ def datapath(filename): module_directory = os.path.dirname(os.path.abspath(__file__)) return os.path.join(module_directory, 'data', 'test_etxrd', filename) + XRD_FILE = datapath('valid-populated-xrds.xml') NOXRDS_FILE = datapath('not-xrds.xml') NOXRD_FILE = datapath('no-xrd.xml') @@ -23,8 +24,9 @@ def simpleOpenIDTransformer(endpoint): if 'http://openid.net/signon/1.0' not in endpoint.type_uris: return None - delegates = list(endpoint.service_element.findall( - '{http://openid.net/xmlns/1.0}Delegate')) + delegates = list( + endpoint.service_element.findall( + '{http://openid.net/xmlns/1.0}Delegate')) assert len(delegates) == 1 delegate = delegates[0].text return (endpoint.uri, delegate) @@ -52,7 +54,7 @@ class TestServiceParser(unittest.TestCase): ("http://www.schtuff.com/openid", "http://users.schtuff.com/josh"), ("http://www.livejournal.com/openid/server.bml", "http://www.livejournal.com/users/nedthealpaca/"), - ] + ] it = iter(services) for (server_url, delegate) in expectedServices: @@ -73,7 +75,7 @@ class TestServiceParser(unittest.TestCase): self.assertEqual(service.uri, uri) break else: - self.fail('Did not find %r service' % (type_uri,)) + self.fail('Did not find %r service' % (type_uri, )) def testGetSeveral(self): """Get some services in order""" @@ -81,15 +83,15 @@ class TestServiceParser(unittest.TestCase): # type, URL (TYPEKEY_1_0, None), (LID_2_0, "http://mylid.net/josh"), - ] + ] self._checkServices(expectedServices) def testGetSeveralForOne(self): """Getting services for one Service with several Type elements.""" - types = ['http://lid.netmesh.org/sso/2.0b5', - 'http://lid.netmesh.org/2.0b5' - ] + types = [ + 'http://lid.netmesh.org/sso/2.0b5', 'http://lid.netmesh.org/2.0b5' + ] uri = "http://mylid.net/josh" @@ -106,30 +108,26 @@ class TestServiceParser(unittest.TestCase): not present""" with open(NOXRDS_FILE, 'rb') as f: self.xmldoc = f.read() - self.assertRaises( - etxrd.XRDSError, - services.applyFilter, self.yadis_url, self.xmldoc, None) + self.assertRaises(etxrd.XRDSError, services.applyFilter, + self.yadis_url, self.xmldoc, None) def testEmpty(self): """Make sure that we get an exception when an XRDS element is not present""" self.xmldoc = '' - self.assertRaises( - etxrd.XRDSError, - services.applyFilter, self.yadis_url, self.xmldoc, None) + self.assertRaises(etxrd.XRDSError, services.applyFilter, + self.yadis_url, self.xmldoc, None) def testNoXRD(self): """Make sure that we get an exception when there is no XRD element present.""" with open(NOXRD_FILE, 'rb') as f: self.xmldoc = f.read() - self.assertRaises( - etxrd.XRDSError, - services.applyFilter, self.yadis_url, self.xmldoc, None) + self.assertRaises(etxrd.XRDSError, services.applyFilter, + self.yadis_url, self.xmldoc, None) class TestCanonicalID(unittest.TestCase): - def mkTest(iname, filename, expectedID): """This function builds a method that runs the CanonicalID test for the given set of inputs""" @@ -140,27 +138,23 @@ class TestCanonicalID(unittest.TestCase): with open(filename, 'rb') as f: xrds = etxrd.parseXRDS(f.read()) self._getCanonicalID(iname, xrds, expectedID) + return test - test_delegated = mkTest( - "@ootao*test1", "delegated-20060809.xrds", - "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01") + test_delegated = mkTest("@ootao*test1", "delegated-20060809.xrds", + "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01") - test_delegated_r1 = mkTest( - "@ootao*test1", "delegated-20060809-r1.xrds", - "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01") + test_delegated_r1 = mkTest("@ootao*test1", "delegated-20060809-r1.xrds", + "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01") - test_delegated_r2 = mkTest( - "@ootao*test1", "delegated-20060809-r2.xrds", - "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01") + test_delegated_r2 = mkTest("@ootao*test1", "delegated-20060809-r2.xrds", + "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01") - test_sometimesprefix = mkTest( - "@ootao*test1", "sometimesprefix.xrds", - "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01") + test_sometimesprefix = mkTest("@ootao*test1", "sometimesprefix.xrds", + "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01") - test_prefixsometimes = mkTest( - "@ootao*test1", "prefixsometimes.xrds", - "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01") + test_prefixsometimes = mkTest("@ootao*test1", "prefixsometimes.xrds", + "@!5BAD.2AA.3C72.AF46!0000.0000.3B9A.CA01") test_spoof1 = mkTest("=keturn*isDrummond", "spoof1.xrds", etxrd.XRDSFraud) @@ -191,11 +185,10 @@ class TestCanonicalID(unittest.TestCase): cid = etxrd.getCanonicalID(iname, xrds) self.assertEqual(cid, expectedID and xri.XRI(expectedID)) elif issubclass(expectedID, etxrd.XRDSError): - self.assertRaises(expectedID, etxrd.getCanonicalID, - iname, xrds) + self.assertRaises(expectedID, etxrd.getCanonicalID, iname, xrds) else: - self.fail("Don't know how to test for expected value %r" - % (expectedID,)) + self.fail("Don't know how to test for expected value %r" % + (expectedID, )) if __name__ == '__main__': diff --git a/openid/test/test_extension.py b/openid/test/test_extension.py index e3db2842d96b232d6a89e41b7d3c30afc9a2e4d1..723539855fc24a109f85437324f82047c519582a 100644 --- a/openid/test/test_extension.py +++ b/openid/test/test_extension.py @@ -3,6 +3,7 @@ from openid import message import unittest + class DummyExtension(extension.Extension): ns_uri = 'http://an.extension/' ns_alias = 'dummy' @@ -10,6 +11,7 @@ class DummyExtension(extension.Extension): def getExtensionArgs(self): return {} + class ToMessageTest(unittest.TestCase): def test_OpenID1(self): oid1_msg = message.Message(message.OPENID1_NS) @@ -17,11 +19,10 @@ class ToMessageTest(unittest.TestCase): ext.toMessage(oid1_msg) namespaces = oid1_msg.namespaces self.assertTrue(namespaces.isImplicit(DummyExtension.ns_uri)) - self.assertEqual( - DummyExtension.ns_uri, - namespaces.getNamespaceURI(DummyExtension.ns_alias)) + self.assertEqual(DummyExtension.ns_uri, + namespaces.getNamespaceURI(DummyExtension.ns_alias)) self.assertEqual(DummyExtension.ns_alias, - namespaces.getAlias(DummyExtension.ns_uri)) + namespaces.getAlias(DummyExtension.ns_uri)) def test_OpenID2(self): oid2_msg = message.Message(message.OPENID2_NS) @@ -29,8 +30,7 @@ class ToMessageTest(unittest.TestCase): ext.toMessage(oid2_msg) namespaces = oid2_msg.namespaces self.assertFalse(namespaces.isImplicit(DummyExtension.ns_uri)) - self.assertEqual( - DummyExtension.ns_uri, - namespaces.getNamespaceURI(DummyExtension.ns_alias)) + self.assertEqual(DummyExtension.ns_uri, + namespaces.getNamespaceURI(DummyExtension.ns_alias)) self.assertEqual(DummyExtension.ns_alias, - namespaces.getAlias(DummyExtension.ns_uri)) + namespaces.getAlias(DummyExtension.ns_uri)) diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index c97ffbaf2417f1a003e8cf659db098c2519adcb8..9baac49fb4e1ba034d59b3c678a17a8487379477 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -30,7 +30,6 @@ def failUnlessResponseExpected(expected, actual, extra): def test_fetcher(fetcher, should_raise_exc, server): - def geturl(path): host, port = server.server_address return 'http://%s:%s%s' % (host, port, path) @@ -56,7 +55,7 @@ def test_fetcher(fetcher, should_raise_exc, server): plain('forbidden', 403), plain('error', 500), plain('server_error', 503), - ] + ] for path, expected in cases: fetch_url = geturl(path) @@ -74,7 +73,7 @@ def test_fetcher(fetcher, should_raise_exc, server): 'http://invalid.janrain.com/', 'not:a/url', 'ftp://janrain.com/pub/', - ]: + ]: try: result = fetcher.fetch(err_url) except (KeyboardInterrupt, SystemExit): @@ -89,8 +88,8 @@ def test_fetcher(fetcher, should_raise_exc, server): except Exception as e: assert should_raise_exc else: - assert False, 'An exception was expected for %r (%r)' % ( - fetcher, result) + assert False, 'An exception was expected for %r (%r)' % (fetcher, + result) def run_fetcher_tests(server): @@ -99,20 +98,21 @@ def run_fetcher_tests(server): (fetchers.Urllib2Fetcher, 'urllib2'), (fetchers.CurlHTTPFetcher, 'pycurl'), (fetchers.HTTPLib2Fetcher, 'httplib2'), - ]: + ]: try: exc_fetchers.append(klass()) except RuntimeError as why: - if str(why).startswith('Cannot find %s library' % (library_name,)): + if str(why).startswith('Cannot find %s library' % + (library_name, )): try: __import__(library_name) except ImportError: raise unittest.SkipTest( 'Skipping tests for %r fetcher because ' - 'the library did not import.' % (library_name,)) + 'the library did not import.' % (library_name, )) else: - assert False, ('%s present but not detected' % ( - library_name,)) + assert False, ('%s present but not detected' % + (library_name, )) else: raise @@ -126,6 +126,7 @@ def run_fetcher_tests(server): for f in non_exc_fetchers: test_fetcher(f, False, server) + from http.server import BaseHTTPRequestHandler, HTTPServer @@ -141,7 +142,7 @@ class FetcherTestHandler(BaseHTTPRequestHandler): '/forbidden': (403, None), '/error': (500, None), '/server_error': (503, None), - } + } def log_request(self, *args): pass @@ -161,7 +162,7 @@ class FetcherTestHandler(BaseHTTPRequestHandler): extra_headers = [('Content-type', 'text/plain')] if location is not None: host, port = self.server.server_address - base = ('http://%s:%s' % (host, port,)) + base = ('http://%s:%s' % (host, port, )) location = base + location extra_headers.append(('Location', location)) self._respond(http_code, extra_headers, self.path) @@ -185,7 +186,7 @@ class FetcherTestHandler(BaseHTTPRequestHandler): req = [ ('HTTP method', self.command), ('path', self.path), - ] + ] if message: req.append(('message', message)) @@ -269,11 +270,11 @@ class DefaultFetcherTest(unittest.TestCase): """Make sure that the default fetcher instance wraps exceptions by default""" default_fetcher = fetchers.getDefaultFetcher() - self.assertIsInstance( - default_fetcher, fetchers.ExceptionWrappingFetcher) + self.assertIsInstance(default_fetcher, + fetchers.ExceptionWrappingFetcher) - self.assertRaises(fetchers.HTTPFetchingError, - fetchers.fetch, 'http://invalid.janrain.com/') + self.assertRaises(fetchers.HTTPFetchingError, fetchers.fetch, + 'http://invalid.janrain.com/') def test_notWrapped(self): """Make sure that if we set a non-wrapped fetcher as default, @@ -283,8 +284,9 @@ class DefaultFetcherTest(unittest.TestCase): fetcher = fetchers.Urllib2Fetcher() fetchers.setDefaultFetcher(fetcher, wrap_exceptions=False) - self.assertFalse(isinstance(fetchers.getDefaultFetcher(), - fetchers.ExceptionWrappingFetcher)) + self.assertFalse( + isinstance(fetchers.getDefaultFetcher(), + fetchers.ExceptionWrappingFetcher)) try: fetchers.fetch('http://invalid.janrain.com/') @@ -299,6 +301,7 @@ class DefaultFetcherTest(unittest.TestCase): class Urllib2FetcherTests(unittest.TestCase): '''Make sure a few of the utility methods are also covered by tests.''' + def setUp(self): self.fetcher = fetchers.Urllib2Fetcher() @@ -306,8 +309,10 @@ class Urllib2FetcherTests(unittest.TestCase): ''' Test that the _allowedURL function only lets through the right things. ''' - for url in ["file://localhost/thing.txt", "ftp://server/path", - "sftp://server/path", "ssh://server/path"]: + for url in [ + "file://localhost/thing.txt", "ftp://server/path", + "sftp://server/path", "ssh://server/path" + ]: self.assertEqual(fetchers._allowedURL(url), False) def test_lowerCaseKeys(self): @@ -317,11 +322,16 @@ class Urllib2FetcherTests(unittest.TestCase): def test_parseHeaderValue(self): headers_parsed = [ - ("text/html; charset=latin-1", - ("text/html", {"charset": "latin-1"})), - ("1; mode=block", ("1", {"mode": "block"})), - ("foo; bar=baz; thing=quux", - ("foo", {"bar": "baz", "thing": "quux"})), + ("text/html; charset=latin-1", ("text/html", { + "charset": "latin-1" + })), + ("1; mode=block", ("1", { + "mode": "block" + })), + ("foo; bar=baz; thing=quux", ("foo", { + "bar": "baz", + "thing": "quux" + })), ] for s, p in headers_parsed: self.assertEqual(self.fetcher._parseHeaderValue(s), p) diff --git a/openid/test/test_htmldiscover.py b/openid/test/test_htmldiscover.py index cc150638991de090263ff262357bf2fcf81fd5d6..40fa22df3b9fb6bf9ab160c4f0b4175a2222dac1 100644 --- a/openid/test/test_htmldiscover.py +++ b/openid/test/test_htmldiscover.py @@ -14,7 +14,8 @@ class BadLinksTestCase(datadriven.DataDrivenTestCase): self.data = data def runOneTest(self): - actual = OpenIDServiceEndpoint.fromHTML('http://unused.url/', self.data) + actual = OpenIDServiceEndpoint.fromHTML('http://unused.url/', + self.data) expected = [] self.assertEqual(expected, actual) diff --git a/openid/test/test_message.py b/openid/test/test_message.py index 7c61763138bf1c61698b6baa9426e28900cd8be4..8f543cc02c724990f492ac15c9ca59445405a028 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -14,13 +14,11 @@ def mkGetArgTest(ns, key, expected=None): a_default = object() self.assertEqual(self.msg.getArg(ns, key), expected) if expected is None: - self.assertEqual( - self.msg.getArg(ns, key, a_default), a_default) - self.assertRaises( - KeyError, self.msg.getArg, ns, key, message.no_default) + self.assertEqual(self.msg.getArg(ns, key, a_default), a_default) + self.assertRaises(KeyError, self.msg.getArg, ns, key, + message.no_default) else: - self.assertEqual( - self.msg.getArg(ns, key, a_default), expected) + self.assertEqual(self.msg.getArg(ns, key, a_default), expected) self.assertEqual( self.msg.getArg(ns, key, message.no_default), expected) @@ -55,8 +53,8 @@ class EmptyMessageTest(unittest.TestCase): # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.assertRaises(message.UndefinedOpenIDNamespace, - self.msg.getKey, message.OPENID_NS, 'foo') + self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.getKey, + message.OPENID_NS, 'foo') def test_getKeyBARE(self): self.assertEqual(self.msg.getKey(message.BARE_NS, 'foo'), 'foo') @@ -68,16 +66,16 @@ class EmptyMessageTest(unittest.TestCase): self.assertEqual(self.msg.getKey(message.OPENID2_NS, 'foo'), None) def test_getKeyNS3(self): - self.assertEqual(self.msg.getKey('urn:nothing-significant', 'foo'), - None) + self.assertEqual( + self.msg.getKey('urn:nothing-significant', 'foo'), None) def test_hasKey(self): # Could reasonably return False instead of raising an # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.assertRaises(message.UndefinedOpenIDNamespace, - self.msg.hasKey, message.OPENID_NS, 'foo') + self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.hasKey, + message.OPENID_NS, 'foo') def test_hasKeyBARE(self): self.assertEqual(self.msg.hasKey(message.BARE_NS, 'foo'), False) @@ -89,27 +87,29 @@ class EmptyMessageTest(unittest.TestCase): self.assertEqual(self.msg.hasKey(message.OPENID2_NS, 'foo'), False) def test_hasKeyNS3(self): - self.assertEqual(self.msg.hasKey('urn:nothing-significant', 'foo'), - False) + self.assertEqual( + self.msg.hasKey('urn:nothing-significant', 'foo'), False) def test_getAliasedArgSuccess(self): - msg = message.Message.fromPostArgs({'openid.ns.test': 'urn://foo', - 'openid.test.flub': 'bogus'}) + msg = message.Message.fromPostArgs({ + 'openid.ns.test': 'urn://foo', + 'openid.test.flub': 'bogus' + }) actual_uri = msg.getAliasedArg('ns.test', message.no_default) self.assertEqual("urn://foo", actual_uri) def test_getAliasedArgFailure(self): msg = message.Message.fromPostArgs({'openid.test.flub': 'bogus'}) - self.assertRaises(KeyError, - msg.getAliasedArg, 'ns.test', message.no_default) + self.assertRaises(KeyError, msg.getAliasedArg, 'ns.test', + message.no_default) def test_getArg(self): # Could reasonably return None instead of raising an # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.assertRaises(message.UndefinedOpenIDNamespace, - self.msg.getArg, message.OPENID_NS, 'foo') + self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.getArg, + message.OPENID_NS, 'foo') test_getArgBARE = mkGetArgTest(message.BARE_NS, 'foo') test_getArgNS1 = mkGetArgTest(message.OPENID1_NS, 'foo') @@ -121,8 +121,8 @@ class EmptyMessageTest(unittest.TestCase): # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.assertRaises(message.UndefinedOpenIDNamespace, - self.msg.getArgs, message.OPENID_NS) + self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.getArgs, + message.OPENID_NS) def test_getArgsBARE(self): self.assertEqual(self.msg.getArgs(message.BARE_NS), {}) @@ -138,8 +138,8 @@ class EmptyMessageTest(unittest.TestCase): def test_updateArgs(self): self.assertRaises(message.UndefinedOpenIDNamespace, - self.msg.updateArgs, message.OPENID_NS, - {'does not': 'matter'}) + self.msg.updateArgs, message.OPENID_NS, + {'does not': 'matter'}) def _test_updateArgsNS(self, ns): update_args = { @@ -164,9 +164,8 @@ class EmptyMessageTest(unittest.TestCase): self._test_updateArgsNS('urn:nothing-significant') def test_setArg(self): - self.assertRaises(message.UndefinedOpenIDNamespace, - self.msg.setArg, message.OPENID_NS, - 'does not', 'matter') + self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.setArg, + message.OPENID_NS, 'does not', 'matter') def _test_setArgNS(self, ns): key = 'Camper van Beethoven' @@ -188,8 +187,8 @@ class EmptyMessageTest(unittest.TestCase): self._test_setArgNS('urn:nothing-significant') def test_setArgToNone(self): - self.assertRaises(AssertionError, self.msg.setArg, - message.OPENID1_NS, 'op_endpoint', None) + self.assertRaises(AssertionError, self.msg.setArg, message.OPENID1_NS, + 'op_endpoint', None) def test_delArg(self): # Could reasonably raise KeyError instead of raising @@ -197,8 +196,8 @@ class EmptyMessageTest(unittest.TestCase): # right, since this case should only happen when you're # building a message from scratch and so have no default # namespace. - self.assertRaises(message.UndefinedOpenIDNamespace, - self.msg.delArg, message.OPENID_NS, 'key') + self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.delArg, + message.OPENID_NS, 'key') def _test_delArgNS(self, ns): key = 'Camper van Beethoven' @@ -231,20 +230,17 @@ class OpenID1MessageTest(unittest.TestCase): }) def test_toPostArgs(self): - self.assertEqual(self.msg.toPostArgs(), { - 'openid.mode': 'error', - 'openid.error': 'unit test' - }) + self.assertEqual(self.msg.toPostArgs(), + {'openid.mode': 'error', + 'openid.error': 'unit test'}) def test_toArgs(self): - self.assertEqual(self.msg.toArgs(), { - 'mode': 'error', - 'error': 'unit test' - }) + self.assertEqual(self.msg.toArgs(), + {'mode': 'error', + 'error': 'unit test'}) def test_toKVForm(self): - self.assertEqual(self.msg.toKVForm(), - b'error:unit test\nmode:error\n') + self.assertEqual(self.msg.toKVForm(), b'error:unit test\nmode:error\n') def test_toURLEncoded(self): self.assertEqual(self.msg.toURLEncoded(), @@ -258,17 +254,16 @@ class OpenID1MessageTest(unittest.TestCase): self.assertEqual(actual[len(base_url)], '?') query = actual[len(base_url) + 1:] parsed = urllib.parse.parse_qs(query) - self.assertEqual(parsed, { - 'openid.mode': ['error'], - 'openid.error': ['unit test'] - }) + self.assertEqual( + parsed, {'openid.mode': ['error'], + 'openid.error': ['unit test']}) def test_getOpenID(self): self.assertEqual(self.msg.getOpenIDNamespace(), message.OPENID1_NS) def test_getKeyOpenID(self): - self.assertEqual(self.msg.getKey(message.OPENID_NS, 'mode'), - 'openid.mode') + self.assertEqual( + self.msg.getKey(message.OPENID_NS, 'mode'), 'openid.mode') def test_getKeyBARE(self): self.assertEqual(self.msg.getKey(message.BARE_NS, 'mode'), 'mode') @@ -294,8 +289,7 @@ class OpenID1MessageTest(unittest.TestCase): self.assertEqual(self.msg.hasKey(message.OPENID1_NS, 'mode'), True) def test_hasKeyNS2(self): - self.assertEqual( - self.msg.hasKey(message.OPENID2_NS, 'mode'), False) + self.assertEqual(self.msg.hasKey(message.OPENID2_NS, 'mode'), False) def test_hasKeyNS3(self): self.assertEqual( @@ -308,19 +302,21 @@ class OpenID1MessageTest(unittest.TestCase): test_getArgNS3 = mkGetArgTest('urn:nothing-significant', 'mode') def test_getArgs(self): - self.assertEqual(self.msg.getArgs(message.OPENID_NS), { - 'mode': 'error', - 'error': 'unit test', - }) + self.assertEqual( + self.msg.getArgs(message.OPENID_NS), { + 'mode': 'error', + 'error': 'unit test', + }) def test_getArgsBARE(self): self.assertEqual(self.msg.getArgs(message.BARE_NS), {}) def test_getArgsNS1(self): - self.assertEqual(self.msg.getArgs(message.OPENID1_NS), { - 'mode': 'error', - 'error': 'unit test', - }) + self.assertEqual( + self.msg.getArgs(message.OPENID1_NS), { + 'mode': 'error', + 'error': 'unit test', + }) def test_getArgsNS2(self): self.assertEqual(self.msg.getArgs(message.OPENID2_NS), {}) @@ -334,7 +330,7 @@ class OpenID1MessageTest(unittest.TestCase): update_args = { 'Camper van Beethoven': 'David Lowery', 'Magnolia Electric Co.': 'Jason Molina', - } + } self.assertEqual(self.msg.getArgs(ns), before) self.msg.updateArgs(ns, update_args) @@ -343,16 +339,16 @@ class OpenID1MessageTest(unittest.TestCase): self.assertEqual(self.msg.getArgs(ns), after) def test_updateArgs(self): - self._test_updateArgsNS(message.OPENID_NS, - before={'mode': 'error', - 'error': 'unit test'}) + self._test_updateArgsNS( + message.OPENID_NS, before={'mode': 'error', + 'error': 'unit test'}) def test_updateArgsBARE(self): self._test_updateArgsNS(message.BARE_NS) def test_updateArgsNS1(self): - self._test_updateArgsNS(message.OPENID1_NS, - before={'mode': 'error', + self._test_updateArgsNS( + message.OPENID1_NS, before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsNS2(self): @@ -438,15 +434,17 @@ class OpenID1ExplicitMessageTest(unittest.TestCase): }) def test_toKVForm(self): - self.assertEqual( - self.msg.toKVForm(), - bytes('error:unit test\nmode:error\nns:%s\n' % - message.OPENID1_NS, encoding="utf-8")) + self.assertEqual(self.msg.toKVForm(), + bytes( + 'error:unit test\nmode:error\nns:%s\n' % + message.OPENID1_NS, + encoding="utf-8")) def test_toURLEncoded(self): self.assertEqual( self.msg.toURLEncoded(), - 'openid.error=unit+test&openid.mode=error&openid.ns=http%3A%2F%2Fopenid.net%2Fsignon%2F1.0') + 'openid.error=unit+test&openid.mode=error&openid.ns=http%3A%2F%2Fopenid.net%2Fsignon%2F1.0' + ) def test_toURL(self): base_url = 'http://base.url/' @@ -476,26 +474,26 @@ class OpenID2MessageTest(unittest.TestCase): self.msg.setArg(message.BARE_NS, "xey", "value") def test_toPostArgs(self): - self.assertEqual( - self.msg.toPostArgs(), { - 'openid.mode': 'error', - 'openid.error': 'unit test', - 'openid.ns': message.OPENID2_NS, - 'xey': 'value', - }) + self.assertEqual(self.msg.toPostArgs(), { + 'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID2_NS, + 'xey': 'value', + }) def test_toPostArgs_bug_with_utf8_encoded_values(self): - msg = message.Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': 'unit test', - 'openid.ns': message.OPENID2_NS - }) + msg = message.Message.fromPostArgs({ + 'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID2_NS + }) msg.setArg(message.BARE_NS, 'ünicöde_key', 'ünicöde_välüe') - self.assertEqual(msg.toPostArgs(), - {'openid.mode': 'error', - 'openid.error': 'unit test', - 'openid.ns': message.OPENID2_NS, - 'ünicöde_key': 'ünicöde_välüe', - }) + self.assertEqual(msg.toPostArgs(), { + 'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID2_NS, + 'ünicöde_key': 'ünicöde_välüe', + }) def test_toArgs(self): # This method can't tolerate BARE_NS. @@ -509,15 +507,16 @@ class OpenID2MessageTest(unittest.TestCase): def test_toKVForm(self): # Can't tolerate BARE_NS in kvform self.msg.delArg(message.BARE_NS, "xey") - self.assertEqual( - self.msg.toKVForm(), - bytes('error:unit test\nmode:error\nns:%s\n' % - message.OPENID2_NS, encoding="utf-8")) + self.assertEqual(self.msg.toKVForm(), + bytes( + 'error:unit test\nmode:error\nns:%s\n' % + message.OPENID2_NS, + encoding="utf-8")) def _test_urlencoded(self, s): expected = ('openid.error=unit+test&openid.mode=error&' - 'openid.ns=%s&xey=value' % ( - urllib.parse.quote(message.OPENID2_NS, ''),)) + 'openid.ns=%s&xey=value' % + (urllib.parse.quote(message.OPENID2_NS, ''), )) self.assertEqual(s, expected) def test_toURLEncoded(self): @@ -536,15 +535,14 @@ class OpenID2MessageTest(unittest.TestCase): self.assertEqual(self.msg.getOpenIDNamespace(), message.OPENID2_NS) def test_getKeyOpenID(self): - self.assertEqual(self.msg.getKey(message.OPENID_NS, 'mode'), - 'openid.mode') + self.assertEqual( + self.msg.getKey(message.OPENID_NS, 'mode'), 'openid.mode') def test_getKeyBARE(self): self.assertEqual(self.msg.getKey(message.BARE_NS, 'mode'), 'mode') def test_getKeyNS1(self): - self.assertEqual( - self.msg.getKey(message.OPENID1_NS, 'mode'), None) + self.assertEqual(self.msg.getKey(message.OPENID1_NS, 'mode'), None) def test_getKeyNS2(self): self.assertEqual( @@ -561,12 +559,10 @@ class OpenID2MessageTest(unittest.TestCase): self.assertEqual(self.msg.hasKey(message.BARE_NS, 'mode'), False) def test_hasKeyNS1(self): - self.assertEqual( - self.msg.hasKey(message.OPENID1_NS, 'mode'), False) + self.assertEqual(self.msg.hasKey(message.OPENID1_NS, 'mode'), False) def test_hasKeyNS2(self): - self.assertEqual( - self.msg.hasKey(message.OPENID2_NS, 'mode'), True) + self.assertEqual(self.msg.hasKey(message.OPENID2_NS, 'mode'), True) def test_hasKeyNS3(self): self.assertEqual( @@ -579,23 +575,24 @@ class OpenID2MessageTest(unittest.TestCase): test_getArgNS3 = mkGetArgTest('urn:nothing-significant', 'mode') def test_getArgsOpenID(self): - self.assertEqual(self.msg.getArgs(message.OPENID_NS), { - 'mode': 'error', - 'error': 'unit test', - }) + self.assertEqual( + self.msg.getArgs(message.OPENID_NS), { + 'mode': 'error', + 'error': 'unit test', + }) def test_getArgsBARE(self): - self.assertEqual(self.msg.getArgs(message.BARE_NS), - {'xey': 'value'}) + self.assertEqual(self.msg.getArgs(message.BARE_NS), {'xey': 'value'}) def test_getArgsNS1(self): self.assertEqual(self.msg.getArgs(message.OPENID1_NS), {}) def test_getArgsNS2(self): - self.assertEqual(self.msg.getArgs(message.OPENID2_NS), { - 'mode': 'error', - 'error': 'unit test', - }) + self.assertEqual( + self.msg.getArgs(message.OPENID2_NS), { + 'mode': 'error', + 'error': 'unit test', + }) def test_getArgsNS3(self): self.assertEqual(self.msg.getArgs('urn:nothing-significant'), {}) @@ -615,20 +612,19 @@ class OpenID2MessageTest(unittest.TestCase): self.assertEqual(self.msg.getArgs(ns), after) def test_updateArgsOpenID(self): - self._test_updateArgsNS(message.OPENID_NS, - before={'mode': 'error', - 'error': 'unit test'}) + self._test_updateArgsNS( + message.OPENID_NS, before={'mode': 'error', + 'error': 'unit test'}) def test_updateArgsBARE(self): - self._test_updateArgsNS(message.BARE_NS, - before={'xey': 'value'}) + self._test_updateArgsNS(message.BARE_NS, before={'xey': 'value'}) def test_updateArgsNS1(self): self._test_updateArgsNS(message.OPENID1_NS) def test_updateArgsNS2(self): - self._test_updateArgsNS(message.OPENID2_NS, - before={'mode': 'error', + self._test_updateArgsNS( + message.OPENID2_NS, before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsNS3(self): @@ -661,37 +657,47 @@ class OpenID2MessageTest(unittest.TestCase): allowed as namespace aliases.""" for f in message.OPENID_PROTOCOL_FIELDS + ['dotted.alias']: - args = {'openid.ns.%s' % f: 'blah', - 'openid.%s.foo' % f: 'test'} + args = {'openid.ns.%s' % f: 'blah', 'openid.%s.foo' % f: 'test'} # .fromPostArgs covers .fromPostArgs, .fromOpenIDArgs, # ._fromOpenIDArgs, and .fromOpenIDArgs (since it calls # .fromPostArgs). - self.assertRaises(AssertionError, self.msg.fromPostArgs, - args) + self.assertRaises(AssertionError, self.msg.fromPostArgs, args) def test_mysterious_missing_namespace_bug(self): """A failing test for bug #112""" openid_args = { - 'assoc_handle': '{{HMAC-SHA256}{1211477242.29743}{v5cadg==}', - 'claimed_id': 'http://nerdbank.org/OPAffirmative/AffirmativeIdentityWithSregNoAssoc.aspx', - 'ns.sreg': 'http://openid.net/extensions/sreg/1.1', - 'response_nonce': '2008-05-22T17:27:22ZUoW5.\\NV', - 'signed': 'return_to,identity,claimed_id,op_endpoint,response_nonce,ns.sreg,sreg.email,sreg.nickname,assoc_handle', - 'sig': 'e3eGZ10+TNRZitgq5kQlk5KmTKzFaCRI8OrRoXyoFa4=', - 'mode': 'check_authentication', - 'op_endpoint': 'http://nerdbank.org/OPAffirmative/ProviderNoAssoc.aspx', - 'sreg.nickname': 'Andy', - 'return_to': 'http://localhost.localdomain:8001/process?janrain_nonce=2008-05-22T17%3A27%3A21ZnxHULd', - 'invalidate_handle': '{{HMAC-SHA1}{1211477241.92242}{H0akXw==}', - 'identity': 'http://nerdbank.org/OPAffirmative/AffirmativeIdentityWithSregNoAssoc.aspx', - 'sreg.email': 'a@b.com' + 'assoc_handle': + '{{HMAC-SHA256}{1211477242.29743}{v5cadg==}', + 'claimed_id': + 'http://nerdbank.org/OPAffirmative/AffirmativeIdentityWithSregNoAssoc.aspx', + 'ns.sreg': + 'http://openid.net/extensions/sreg/1.1', + 'response_nonce': + '2008-05-22T17:27:22ZUoW5.\\NV', + 'signed': + 'return_to,identity,claimed_id,op_endpoint,response_nonce,ns.sreg,sreg.email,sreg.nickname,assoc_handle', + 'sig': + 'e3eGZ10+TNRZitgq5kQlk5KmTKzFaCRI8OrRoXyoFa4=', + 'mode': + 'check_authentication', + 'op_endpoint': + 'http://nerdbank.org/OPAffirmative/ProviderNoAssoc.aspx', + 'sreg.nickname': + 'Andy', + 'return_to': + 'http://localhost.localdomain:8001/process?janrain_nonce=2008-05-22T17%3A27%3A21ZnxHULd', + 'invalidate_handle': + '{{HMAC-SHA1}{1211477241.92242}{H0akXw==}', + 'identity': + 'http://nerdbank.org/OPAffirmative/AffirmativeIdentityWithSregNoAssoc.aspx', + 'sreg.email': + 'a@b.com' } m = message.Message.fromOpenIDArgs(openid_args) - self.assertTrue( - ('http://openid.net/extensions/sreg/1.1', 'sreg') in - m.namespaces.items()) + self.assertTrue(('http://openid.net/extensions/sreg/1.1', + 'sreg') in m.namespaces.items()) missing = [] if isinstance(openid_args['signed'], bytes): oid_args_signed = openid_args['signed'].decode("utf-8") @@ -706,20 +712,34 @@ class OpenID2MessageTest(unittest.TestCase): def test_112B(self): args = { - 'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', - 'openid.claimed_id': 'http://binkley.lan/user/test01', - 'openid.identity': 'http://test01.binkley.lan/', - 'openid.mode': 'id_res', - 'openid.ns': 'http://specs.openid.net/auth/2.0', - 'openid.ns.pape': 'http://specs.openid.net/extensions/pape/1.0', - 'openid.op_endpoint': 'http://binkley.lan/server', - 'openid.pape.auth_policies': 'none', - 'openid.pape.auth_time': '2008-01-28T20:42:36Z', - 'openid.pape.nist_auth_level': '0', - 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', - 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', - 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', - 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,ns.pape,pape.nist_auth_level,pape.auth_policies' + 'openid.assoc_handle': + 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', + 'openid.claimed_id': + 'http://binkley.lan/user/test01', + 'openid.identity': + 'http://test01.binkley.lan/', + 'openid.mode': + 'id_res', + 'openid.ns': + 'http://specs.openid.net/auth/2.0', + 'openid.ns.pape': + 'http://specs.openid.net/extensions/pape/1.0', + 'openid.op_endpoint': + 'http://binkley.lan/server', + 'openid.pape.auth_policies': + 'none', + 'openid.pape.auth_time': + '2008-01-28T20:42:36Z', + 'openid.pape.nist_auth_level': + '0', + 'openid.response_nonce': + '2008-01-28T21:07:04Z99Q=', + 'openid.return_to': + 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', + 'openid.sig': + 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', + 'openid.signed': + 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,ns.pape,pape.nist_auth_level,pape.auth_policies' } m = message.Message.fromPostArgs(args) missing = [] @@ -735,12 +755,9 @@ class OpenID2MessageTest(unittest.TestCase): self.assertTrue(m.isOpenID2()) def test_implicit_sreg_ns(self): - openid_args = { - 'sreg.email': 'a@b.com' - } + openid_args = {'sreg.email': 'a@b.com'} m = message.Message.fromOpenIDArgs(openid_args) - self.assertTrue((sreg.ns_uri, 'sreg') in - list(m.namespaces.items())) + self.assertTrue((sreg.ns_uri, 'sreg') in list(m.namespaces.items())) self.assertEqual('a@b.com', m.getArg(sreg.ns_uri, 'email')) self.assertEqual(openid_args, m.toArgs()) self.assertTrue(m.isOpenID1()) @@ -782,8 +799,7 @@ class OpenID2MessageTest(unittest.TestCase): self.assertTrue(self.msg.getArg(ns, key) == value_2) def test_argList(self): - self.assertRaises(TypeError, self.msg.fromPostArgs, - {'arg': [1, 2, 3]}) + self.assertRaises(TypeError, self.msg.fromPostArgs, {'arg': [1, 2, 3]}) def test_isOpenID1(self): self.assertFalse(self.msg.isOpenID1()) @@ -800,14 +816,14 @@ class MessageTest(unittest.TestCase): 'openid.identity': 'http://bogus.example.invalid:port/', 'openid.assoc_handle': 'FLUB', 'openid.return_to': 'Neverland', - } + } self.action_url = 'scheme://host:port/path?query' self.form_tag_attrs = { 'company': 'janrain', 'class': 'fancyCSS', - } + } self.submit_text = 'GO!' @@ -817,7 +833,7 @@ class MessageTest(unittest.TestCase): 'accept-charset': 'UTF-8', 'enctype': 'application/x-www-form-urlencoded', 'method': 'post', - } + } def _checkForm(self, html, message_, action_url, form_tag_attrs, submit_text): @@ -866,7 +882,7 @@ class MessageTest(unittest.TestCase): (e.attrib['name'], value, e.attrib['value']) break else: - self.fail("Post arg '%s' not found in form" % (name,)) + self.fail("Post arg '%s' not found in form" % (name, )) for e in hiddens: assert e.attrib['name'] in list(message_.toPostArgs().keys()), \ @@ -895,8 +911,8 @@ class MessageTest(unittest.TestCase): m = message.Message.fromPostArgs(self.postargs) html = m.toFormMarkup(self.action_url, self.form_tag_attrs, self.submit_text) - self._checkForm(html, m, self.action_url, - self.form_tag_attrs, self.submit_text) + self._checkForm(html, m, self.action_url, self.form_tag_attrs, + self.submit_text) def test_toFormMarkup_bug_with_utf8_values(self): postargs = { @@ -906,7 +922,7 @@ class MessageTest(unittest.TestCase): 'openid.assoc_handle': 'FLUB', 'openid.return_to': 'Neverland', 'ünicöde_key': 'ünicöde_välüe', - } + } m = message.Message.fromPostArgs(postargs) # Calling m.toFormMarkup with lxml used for ElementTree will throw # a ValueError. @@ -920,8 +936,8 @@ class MessageTest(unittest.TestCase): 'ünicöde_key' in html, 'UTF-8 bytes should not convert to XML character references') self.assertFalse( - 'ünicöde_välüe' in html, - 'UTF-8 bytes should not convert to XML character references') + 'ünicöde_välüe' in + html, 'UTF-8 bytes should not convert to XML character references') def test_overrideMethod(self): """Be sure that caller cannot change form method to GET.""" @@ -932,8 +948,8 @@ class MessageTest(unittest.TestCase): html = m.toFormMarkup(self.action_url, self.form_tag_attrs, self.submit_text) - self._checkForm(html, m, self.action_url, - self.form_tag_attrs, self.submit_text) + self._checkForm(html, m, self.action_url, self.form_tag_attrs, + self.submit_text) def test_overrideRequired(self): """Be sure that caller CANNOT change the form charset for @@ -944,10 +960,8 @@ class MessageTest(unittest.TestCase): tag_attrs['accept-charset'] = 'UCS4' tag_attrs['enctype'] = 'invalid/x-broken' - html = m.toFormMarkup(self.action_url, tag_attrs, - self.submit_text) - self._checkForm(html, m, self.action_url, - tag_attrs, self.submit_text) + html = m.toFormMarkup(self.action_url, tag_attrs, self.submit_text) + self._checkForm(html, m, self.action_url, tag_attrs, self.submit_text) def test_setOpenIDNamespace_invalid(self): m = message.Message() @@ -962,23 +976,23 @@ class MessageTest(unittest.TestCase): 'http%3A%2F%2Fspecs.openid.net%2Fauth%2F2.0', # This is a Type URI, not a openid.ns value. 'http://specs.openid.net/auth/2.0/signon', - ] + ] for x in invalid_things: self.assertRaises(message.InvalidOpenIDNamespace, - m.setOpenIDNamespace, x, False) + m.setOpenIDNamespace, x, False) def test_isOpenID1(self): v1_namespaces = [ # Yes, there are two of them. 'http://openid.net/signon/1.1', 'http://openid.net/signon/1.0', - ] + ] for ns in v1_namespaces: m = message.Message(ns) - self.assertTrue(m.isOpenID1(), "%r not recognized as OpenID 1" % - (ns,)) + self.assertTrue(m.isOpenID1(), + "%r not recognized as OpenID 1" % (ns, )) self.assertEqual(ns, m.getOpenIDNamespace()) self.assertTrue( m.namespaces.isImplicit(ns), @@ -1023,7 +1037,7 @@ class MessageTest(unittest.TestCase): 'openid.return_to': 'http://drupal.invalid/return_to', 'openid.sreg.required': 'nickname,email', 'openid.trust_root': 'http://drupal.invalid', - } + } m = message.Message.fromPostArgs(query) self.assertTrue(m.isOpenID1()) diff --git a/openid/test/test_negotiation.py b/openid/test/test_negotiation.py index f6806a29768af0e026069e3eb585891235beef41..c1eaadc21bbe0674e457bda75b6d8cc94fbd989b 100644 --- a/openid/test/test_negotiation.py +++ b/openid/test/test_negotiation.py @@ -1,4 +1,3 @@ - import unittest from .support import CatchLogs @@ -7,6 +6,7 @@ from openid import association from openid.consumer.consumer import GenericConsumer, ServerError from openid.consumer.discover import OpenIDServiceEndpoint, OPENID_2_0_TYPE + class ErrorRaisingConsumer(GenericConsumer): """ A consumer whose _requestAssocation will return predefined results @@ -27,11 +27,13 @@ class ErrorRaisingConsumer(GenericConsumer): else: return m + class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): """ Test the session type negotiation behavior of an OpenID 2 consumer. """ + def setUp(self): CatchLogs.setUp(self) self.consumer = ErrorRaisingConsumer(store=None) @@ -45,9 +47,13 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): Test the case where the response to an associate request is a server error or is otherwise undecipherable. """ - self.consumer.return_messages = [Message(self.endpoint.preferredNamespace())] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - self.failUnlessLogMatches('Server error when requesting an association') + self.consumer.return_messages = [ + Message(self.endpoint.preferredNamespace()) + ] + self.assertEqual( + self.consumer._negotiateAssociation(self.endpoint), None) + self.failUnlessLogMatches( + 'Server error when requesting an association') def testEmptyAssocType(self): """ @@ -61,11 +67,13 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): msg.setArg(OPENID_NS, 'session_type', 'new-session-type') self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) + self.assertEqual( + self.consumer._negotiateAssociation(self.endpoint), None) - self.failUnlessLogMatches('Unsupported association type', - 'Server responded with unsupported association ' + - 'session but did not supply a fallback.') + self.failUnlessLogMatches( + 'Unsupported association type', + 'Server responded with unsupported association ' + + 'session but did not supply a fallback.') def testEmptySessionType(self): """ @@ -79,11 +87,13 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): # not set: msg.setArg(OPENID_NS, 'session_type', None) self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) + self.assertEqual( + self.consumer._negotiateAssociation(self.endpoint), None) - self.failUnlessLogMatches('Unsupported association type', - 'Server responded with unsupported association ' + - 'session but did not supply a fallback.') + self.failUnlessLogMatches( + 'Unsupported association type', + 'Server responded with unsupported association ' + + 'session but did not supply a fallback.') def testNotAllowed(self): """ @@ -103,10 +113,12 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): msg.setArg(OPENID_NS, 'session_type', 'not-allowed') self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) + self.assertEqual( + self.consumer._negotiateAssociation(self.endpoint), None) - self.failUnlessLogMatches('Unsupported association type', - 'Server sent unsupported session/association type:') + self.failUnlessLogMatches( + 'Unsupported association type', + 'Server sent unsupported session/association type:') def testUnsupportedWithRetry(self): """ @@ -119,11 +131,12 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): msg.setArg(OPENID_NS, 'assoc_type', 'HMAC-SHA1') msg.setArg(OPENID_NS, 'session_type', 'DH-SHA1') - assoc = association.Association( - 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') + assoc = association.Association('handle', 'secret', 'issued', 10000, + 'HMAC-SHA1') self.consumer.return_messages = [msg, assoc] - self.assertTrue(self.consumer._negotiateAssociation(self.endpoint) is assoc) + self.assertTrue( + self.consumer._negotiateAssociation(self.endpoint) is assoc) self.failUnlessLogMatches('Unsupported association type') @@ -138,26 +151,31 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): msg.setArg(OPENID_NS, 'assoc_type', 'HMAC-SHA1') msg.setArg(OPENID_NS, 'session_type', 'DH-SHA1') - self.consumer.return_messages = [msg, - Message(self.endpoint.preferredNamespace())] + self.consumer.return_messages = [ + msg, Message(self.endpoint.preferredNamespace()) + ] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) + self.assertEqual( + self.consumer._negotiateAssociation(self.endpoint), None) self.failUnlessLogMatches('Unsupported association type', - 'Server %s refused' % (self.endpoint.server_url)) + 'Server %s refused' % + (self.endpoint.server_url)) def testValid(self): """ Test the valid case, wherein an association is returned on the first attempt to get one. """ - assoc = association.Association( - 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') + assoc = association.Association('handle', 'secret', 'issued', 10000, + 'HMAC-SHA1') self.consumer.return_messages = [assoc] - self.assertTrue(self.consumer._negotiateAssociation(self.endpoint) is assoc) + self.assertTrue( + self.consumer._negotiateAssociation(self.endpoint) is assoc) self.failUnlessLogEmpty() + class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): """ Tests for the OpenID 1 consumer association session behavior. See @@ -168,6 +186,7 @@ class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): these tests pass openid2-style messages to the openid 1 association processing logic to be sure it ignores the extra data. """ + def setUp(self): CatchLogs.setUp(self) self.consumer = ErrorRaisingConsumer(store=None) @@ -177,9 +196,13 @@ class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): self.endpoint.server_url = 'bogus' def testBadResponse(self): - self.consumer.return_messages = [Message(self.endpoint.preferredNamespace())] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - self.failUnlessLogMatches('Server error when requesting an association') + self.consumer.return_messages = [ + Message(self.endpoint.preferredNamespace()) + ] + self.assertEqual( + self.consumer._negotiateAssociation(self.endpoint), None) + self.failUnlessLogMatches( + 'Server error when requesting an association') def testEmptyAssocType(self): msg = Message(self.endpoint.preferredNamespace()) @@ -189,9 +212,11 @@ class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): msg.setArg(OPENID_NS, 'session_type', 'new-session-type') self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) + self.assertEqual( + self.consumer._negotiateAssociation(self.endpoint), None) - self.failUnlessLogMatches('Server error when requesting an association') + self.failUnlessLogMatches( + 'Server error when requesting an association') def testEmptySessionType(self): msg = Message(self.endpoint.preferredNamespace()) @@ -201,9 +226,11 @@ class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): # not set: msg.setArg(OPENID_NS, 'session_type', None) self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) + self.assertEqual( + self.consumer._negotiateAssociation(self.endpoint), None) - self.failUnlessLogMatches('Server error when requesting an association') + self.failUnlessLogMatches( + 'Server error when requesting an association') def testNotAllowed(self): allowed_types = [] @@ -218,9 +245,11 @@ class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): msg.setArg(OPENID_NS, 'session_type', 'not-allowed') self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) + self.assertEqual( + self.consumer._negotiateAssociation(self.endpoint), None) - self.failUnlessLogMatches('Server error when requesting an association') + self.failUnlessLogMatches( + 'Server error when requesting an association') def testUnsupportedWithRetry(self): msg = Message(self.endpoint.preferredNamespace()) @@ -229,28 +258,32 @@ class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): msg.setArg(OPENID_NS, 'assoc_type', 'HMAC-SHA1') msg.setArg(OPENID_NS, 'session_type', 'DH-SHA1') - assoc = association.Association( - 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') + assoc = association.Association('handle', 'secret', 'issued', 10000, + 'HMAC-SHA1') self.consumer.return_messages = [msg, assoc] - self.assertTrue(self.consumer._negotiateAssociation(self.endpoint) is None) + self.assertTrue( + self.consumer._negotiateAssociation(self.endpoint) is None) - self.failUnlessLogMatches('Server error when requesting an association') + self.failUnlessLogMatches( + 'Server error when requesting an association') def testValid(self): - assoc = association.Association( - 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') + assoc = association.Association('handle', 'secret', 'issued', 10000, + 'HMAC-SHA1') self.consumer.return_messages = [assoc] - self.assertTrue(self.consumer._negotiateAssociation(self.endpoint) is assoc) + self.assertTrue( + self.consumer._negotiateAssociation(self.endpoint) is assoc) self.failUnlessLogEmpty() + class TestNegotiatorBehaviors(unittest.TestCase, CatchLogs): def setUp(self): self.allowed_types = [ ('HMAC-SHA1', 'no-encryption'), ('HMAC-SHA256', 'no-encryption'), - ] + ] self.n = association.SessionNegotiator(self.allowed_types) @@ -258,7 +291,8 @@ class TestNegotiatorBehaviors(unittest.TestCase, CatchLogs): self.assertRaises(ValueError, self.n.addAllowedType, 'invalid') def testAddAllowedTypeBadSessionType(self): - self.assertRaises(ValueError, self.n.addAllowedType, 'assoc1', 'invalid') + self.assertRaises(ValueError, self.n.addAllowedType, 'assoc1', + 'invalid') def testAddAllowedTypeContents(self): assoc_type = 'HMAC-SHA1' @@ -267,5 +301,6 @@ class TestNegotiatorBehaviors(unittest.TestCase, CatchLogs): for typ in association.getSessionTypes(assoc_type): self.assertTrue((assoc_type, typ) in self.n.allowed_types) + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_nonce.py b/openid/test/test_nonce.py index 8c84aed5a6de30b85883493b0aadc96d3b1df5cc..95ec99b956431ab1b1b9caa01e1584588ac8c539 100644 --- a/openid/test/test_nonce.py +++ b/openid/test/test_nonce.py @@ -10,6 +10,7 @@ from openid.store.nonce import \ nonce_re = re.compile(r'\A\d{4}-\d\d-\d\dT\d\d:\d\d:\d\dZ') + class NonceTest(unittest.TestCase): def test_mkNonce(self): nonce = mkNonce() @@ -38,6 +39,7 @@ class NonceTest(unittest.TestCase): self.assertEqual(len(salt), 6) self.assertEqual(et, t) + class BadSplitTest(datadriven.DataDrivenTestCase): cases = [ '', @@ -47,7 +49,7 @@ class BadSplitTest(datadriven.DataDrivenTestCase): '1970.01-01T00:00:00Z', 'Thu Sep 7 13:29:31 PDT 2006', 'monkeys', - ] + ] def __init__(self, nonce_str): datadriven.DataDrivenTestCase.__init__(self, nonce_str) @@ -56,6 +58,7 @@ class BadSplitTest(datadriven.DataDrivenTestCase): def runOneTest(self): self.assertRaises(ValueError, splitNonce, self.nonce_str) + class CheckTimestampTest(datadriven.DataDrivenTestCase): cases = [ # exact, no allowed skew @@ -81,7 +84,7 @@ class CheckTimestampTest(datadriven.DataDrivenTestCase): # malformed nonce string ('monkeys', 0, 0, False), - ] + ] def __init__(self, nonce_string, allowed_skew, now, expected): datadriven.DataDrivenTestCase.__init__( @@ -95,9 +98,11 @@ class CheckTimestampTest(datadriven.DataDrivenTestCase): actual = checkTimestamp(self.nonce_string, self.allowed_skew, self.now) self.assertEqual(bool(self.expected), bool(actual)) + def pyUnitTests(): return datadriven.loadTests(__name__) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/test_openidyadis.py b/openid/test/test_openidyadis.py index e482d0cc52eb75807411c09089a1a814b6f1d541..78d8053b7f7f1f24b18f8f1b4ce4182a54ccc872 100644 --- a/openid/test/test_openidyadis.py +++ b/openid/test/test_openidyadis.py @@ -4,7 +4,6 @@ from openid.consumer.discover import \ from openid.yadis.services import applyFilter - XRDS_BOILERPLATE = '''\ <?xml version="1.0" encoding="UTF-8"?> <xrds:XRDS xmlns:xrds="xri://$xrds" @@ -16,10 +15,12 @@ XRDS_BOILERPLATE = '''\ </xrds:XRDS> ''' + def mkXRDS(services): - xrds = XRDS_BOILERPLATE % (services,) + xrds = XRDS_BOILERPLATE % (services, ) return xrds.encode('utf-8') + def mkService(uris=None, type_uris=None, local_id=None, dent=' '): chunks = [dent, '<Service>\n'] dent2 = dent + ' ' @@ -47,16 +48,19 @@ def mkService(uris=None, type_uris=None, local_id=None, dent=' '): return ''.join(chunks) + # Different sets of server URLs for use in the URI tag server_url_options = [ - [], # This case should not generate an endpoint object + [], # This case should not generate an endpoint object ['http://server.url/'], ['https://server.url/'], ['https://server.url/', 'http://server.url/'], - ['https://server.url/', - 'http://server.url/', - 'http://example.server.url/'], - ] + [ + 'https://server.url/', 'http://server.url/', + 'http://example.server.url/' + ], +] + # Used for generating test data def subsets(l): @@ -66,12 +70,13 @@ def subsets(l): subsets_list += [[x] + t for t in subsets_list] return subsets_list + # A couple of example extension type URIs. These are not at all # official, but are just here for testing. ext_types = [ 'http://janrain.com/extension/blah', 'http://openid.net/sreg/1.0', - ] +] # All valid combinations of Type tags that should produce an OpenID endpoint type_uri_options = [ @@ -83,22 +88,20 @@ type_uri_options = [ # All combinations of extension types (including empty extenstion list) for exts in subsets(ext_types) - ] +] # Range of valid Delegate tag values for generating test data local_id_options = [ None, 'http://vanity.domain/', 'https://somewhere/yadis/', - ] +] # All combinations of valid URIs, Type URIs and Delegate tags -data = [ - (uris, type_uris, local_id) - for uris in server_url_options - for type_uris in type_uri_options - for local_id in local_id_options - ] +data = [(uris, type_uris, local_id) + for uris in server_url_options for type_uris in type_uri_options + for local_id in local_id_options] + class OpenIDYadisTest(unittest.TestCase): def __init__(self, uris, type_uris, local_id): @@ -115,15 +118,14 @@ class OpenIDYadisTest(unittest.TestCase): self.yadis_url = 'http://unit.test/' # Create an XRDS document to parse - services = mkService(uris=self.uris, - type_uris=self.type_uris, - local_id=self.local_id) + services = mkService( + uris=self.uris, type_uris=self.type_uris, local_id=self.local_id) self.xrds = mkXRDS(services) def runTest(self): # Parse into endpoint objects that we will check - endpoints = applyFilter( - self.yadis_url, self.xrds, OpenIDServiceEndpoint) + endpoints = applyFilter(self.yadis_url, self.xrds, + OpenIDServiceEndpoint) # make sure there are the same number of endpoints as # URIs. This assumes that the type_uris contains at least one @@ -158,6 +160,7 @@ class OpenIDYadisTest(unittest.TestCase): # Make sure we saw all URIs, and saw each one once self.assertEqual(uris, seen_uris) + def pyUnitTests(): cases = [] for args in data: diff --git a/openid/test/test_pape.py b/openid/test/test_pape.py index 044b9e5db141ade4e88a32b0f50e11840ca524e0..60916671c9a426469db8bb085306e43075ba055a 100644 --- a/openid/test/test_pape.py +++ b/openid/test/test_pape.py @@ -1,4 +1,3 @@ - from openid.extensions import pape import unittest diff --git a/openid/test/test_pape_draft2.py b/openid/test/test_pape_draft2.py index 5275e92c3eec1ea8a71397515fb736a7e21522b5..89ce9984d985734757d237a99d1d7f1db46834ca 100644 --- a/openid/test/test_pape_draft2.py +++ b/openid/test/test_pape_draft2.py @@ -1,10 +1,10 @@ - from openid.extensions.draft import pape2 as pape from openid.message import * from openid.server import server import unittest + class PapeRequestTestCase(unittest.TestCase): def setUp(self): self.req = pape.Request() @@ -15,37 +15,54 @@ class PapeRequestTestCase(unittest.TestCase): self.assertEqual('pape', self.req.ns_alias) req2 = pape.Request([pape.AUTH_MULTI_FACTOR], 1000) - self.assertEqual([pape.AUTH_MULTI_FACTOR], req2.preferred_auth_policies) + self.assertEqual([pape.AUTH_MULTI_FACTOR], + req2.preferred_auth_policies) self.assertEqual(1000, req2.max_auth_age) def test_add_policy_uri(self): self.assertEqual([], self.req.preferred_auth_policies) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.assertEqual([pape.AUTH_MULTI_FACTOR], self.req.preferred_auth_policies) + self.assertEqual([pape.AUTH_MULTI_FACTOR], + self.req.preferred_auth_policies) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.assertEqual([pape.AUTH_MULTI_FACTOR], self.req.preferred_auth_policies) + self.assertEqual([pape.AUTH_MULTI_FACTOR], + self.req.preferred_auth_policies) self.req.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) - self.assertEqual([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], - self.req.preferred_auth_policies) + self.assertEqual( + [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], + self.req.preferred_auth_policies) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.assertEqual([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], - self.req.preferred_auth_policies) + self.assertEqual( + [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], + self.req.preferred_auth_policies) def test_getExtensionArgs(self): - self.assertEqual({'preferred_auth_policies': ''}, self.req.getExtensionArgs()) + self.assertEqual({ + 'preferred_auth_policies': '' + }, self.req.getExtensionArgs()) self.req.addPolicyURI('http://uri') - self.assertEqual({'preferred_auth_policies': 'http://uri'}, self.req.getExtensionArgs()) + self.assertEqual({ + 'preferred_auth_policies': 'http://uri' + }, self.req.getExtensionArgs()) self.req.addPolicyURI('http://zig') - self.assertEqual({'preferred_auth_policies': 'http://uri http://zig'}, self.req.getExtensionArgs()) + self.assertEqual({ + 'preferred_auth_policies': 'http://uri http://zig' + }, self.req.getExtensionArgs()) self.req.max_auth_age = 789 - self.assertEqual({'preferred_auth_policies': 'http://uri http://zig', 'max_auth_age': '789'}, self.req.getExtensionArgs()) + self.assertEqual({ + 'preferred_auth_policies': 'http://uri http://zig', + 'max_auth_age': '789' + }, self.req.getExtensionArgs()) def test_parseExtensionArgs(self): - args = {'preferred_auth_policies': 'http://foo http://bar', - 'max_auth_age': '9'} + args = { + 'preferred_auth_policies': 'http://foo http://bar', + 'max_auth_age': '9' + } self.req.parseExtensionArgs(args) self.assertEqual(9, self.req.max_auth_age) - self.assertEqual(['http://foo','http://bar'], self.req.preferred_auth_policies) + self.assertEqual(['http://foo', 'http://bar'], + self.req.preferred_auth_policies) def test_parseExtensionArgs_empty(self): self.req.parseExtensionArgs({}) @@ -54,16 +71,23 @@ class PapeRequestTestCase(unittest.TestCase): def test_fromOpenIDRequest(self): openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'checkid_setup', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.preferred_auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), - 'pape.max_auth_age': '5476' - }) + 'mode': + 'checkid_setup', + 'ns': + OPENID2_NS, + 'ns.pape': + pape.ns_uri, + 'pape.preferred_auth_policies': + ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), + 'pape.max_auth_age': + '5476' + }) oid_req = server.OpenIDRequest() oid_req.message = openid_req_msg req = pape.Request.fromOpenIDRequest(oid_req) - self.assertEqual([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], req.preferred_auth_policies) + self.assertEqual( + [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], + req.preferred_auth_policies) self.assertEqual(5476, req.max_auth_age) def test_fromOpenIDRequest_no_pape(self): @@ -71,15 +95,16 @@ class PapeRequestTestCase(unittest.TestCase): openid_req = server.OpenIDRequest() openid_req.message = message pape_req = pape.Request.fromOpenIDRequest(openid_req) - assert(pape_req is None) + assert (pape_req is None) def test_preferred_types(self): self.req.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - pt = self.req.preferredTypes([pape.AUTH_MULTI_FACTOR, - pape.AUTH_MULTI_FACTOR_PHYSICAL]) + pt = self.req.preferredTypes( + [pape.AUTH_MULTI_FACTOR, pape.AUTH_MULTI_FACTOR_PHYSICAL]) self.assertEqual([pape.AUTH_MULTI_FACTOR], pt) + class DummySuccessResponse: def __init__(self, message, signed_stuff): self.message = message @@ -88,6 +113,7 @@ class DummySuccessResponse: def getSignedNS(self, ns_uri): return self.signed_stuff + class PapeResponseTestCase(unittest.TestCase): def setUp(self): self.req = pape.Response() @@ -98,7 +124,8 @@ class PapeResponseTestCase(unittest.TestCase): self.assertEqual('pape', self.req.ns_alias) self.assertEqual(None, self.req.nist_auth_level) - req2 = pape.Response([pape.AUTH_MULTI_FACTOR], "2004-12-11T10:30:44Z", 3) + req2 = pape.Response([pape.AUTH_MULTI_FACTOR], "2004-12-11T10:30:44Z", + 3) self.assertEqual([pape.AUTH_MULTI_FACTOR], req2.auth_policies) self.assertEqual("2004-12-11T10:30:44Z", req2.auth_time) self.assertEqual(3, req2.nist_auth_level) @@ -110,20 +137,37 @@ class PapeResponseTestCase(unittest.TestCase): self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) self.assertEqual([pape.AUTH_MULTI_FACTOR], self.req.auth_policies) self.req.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) - self.assertEqual([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], self.req.auth_policies) + self.assertEqual( + [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], + self.req.auth_policies) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.assertEqual([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], self.req.auth_policies) + self.assertEqual( + [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], + self.req.auth_policies) def test_getExtensionArgs(self): - self.assertEqual({'auth_policies': 'none'}, self.req.getExtensionArgs()) + self.assertEqual({ + 'auth_policies': 'none' + }, self.req.getExtensionArgs()) self.req.addPolicyURI('http://uri') - self.assertEqual({'auth_policies': 'http://uri'}, self.req.getExtensionArgs()) + self.assertEqual({ + 'auth_policies': 'http://uri' + }, self.req.getExtensionArgs()) self.req.addPolicyURI('http://zig') - self.assertEqual({'auth_policies': 'http://uri http://zig'}, self.req.getExtensionArgs()) + self.assertEqual({ + 'auth_policies': 'http://uri http://zig' + }, self.req.getExtensionArgs()) self.req.auth_time = "1776-07-04T14:43:12Z" - self.assertEqual({'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z"}, self.req.getExtensionArgs()) + self.assertEqual({ + 'auth_policies': 'http://uri http://zig', + 'auth_time': "1776-07-04T14:43:12Z" + }, self.req.getExtensionArgs()) self.req.nist_auth_level = 3 - self.assertEqual({'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z", 'nist_auth_level': '3'}, self.req.getExtensionArgs()) + self.assertEqual({ + 'auth_policies': 'http://uri http://zig', + 'auth_time': "1776-07-04T14:43:12Z", + 'nist_auth_level': '3' + }, self.req.getExtensionArgs()) def test_getExtensionArgs_error_auth_age(self): self.req.auth_time = "long ago" @@ -138,73 +182,95 @@ class PapeResponseTestCase(unittest.TestCase): self.assertRaises(ValueError, self.req.getExtensionArgs) def test_parseExtensionArgs(self): - args = {'auth_policies': 'http://foo http://bar', - 'auth_time': '1970-01-01T00:00:00Z'} + args = { + 'auth_policies': 'http://foo http://bar', + 'auth_time': '1970-01-01T00:00:00Z' + } self.req.parseExtensionArgs(args) self.assertEqual('1970-01-01T00:00:00Z', self.req.auth_time) - self.assertEqual(['http://foo','http://bar'], self.req.auth_policies) + self.assertEqual(['http://foo', 'http://bar'], self.req.auth_policies) def test_parseExtensionArgs_empty(self): self.req.parseExtensionArgs({}) self.assertEqual(None, self.req.auth_time) self.assertEqual([], self.req.auth_policies) - + def test_parseExtensionArgs_strict_bogus1(self): - args = {'auth_policies': 'http://foo http://bar', - 'auth_time': 'yesterday'} - self.assertRaises(ValueError, self.req.parseExtensionArgs, - args, True) + args = { + 'auth_policies': 'http://foo http://bar', + 'auth_time': 'yesterday' + } + self.assertRaises(ValueError, self.req.parseExtensionArgs, args, True) def test_parseExtensionArgs_strict_bogus2(self): - args = {'auth_policies': 'http://foo http://bar', - 'auth_time': '1970-01-01T00:00:00Z', - 'nist_auth_level': 'some'} - self.assertRaises(ValueError, self.req.parseExtensionArgs, - args, True) - + args = { + 'auth_policies': 'http://foo http://bar', + 'auth_time': '1970-01-01T00:00:00Z', + 'nist_auth_level': 'some' + } + self.assertRaises(ValueError, self.req.parseExtensionArgs, args, True) + def test_parseExtensionArgs_strict_good(self): - args = {'auth_policies': 'http://foo http://bar', - 'auth_time': '1970-01-01T00:00:00Z', - 'nist_auth_level': '0'} + args = { + 'auth_policies': 'http://foo http://bar', + 'auth_time': '1970-01-01T00:00:00Z', + 'nist_auth_level': '0' + } self.req.parseExtensionArgs(args, True) - self.assertEqual(['http://foo','http://bar'], self.req.auth_policies) + self.assertEqual(['http://foo', 'http://bar'], self.req.auth_policies) self.assertEqual('1970-01-01T00:00:00Z', self.req.auth_time) self.assertEqual(0, self.req.nist_auth_level) def test_parseExtensionArgs_nostrict_bogus(self): - args = {'auth_policies': 'http://foo http://bar', - 'auth_time': 'when the cows come home', - 'nist_auth_level': 'some'} + args = { + 'auth_policies': 'http://foo http://bar', + 'auth_time': 'when the cows come home', + 'nist_auth_level': 'some' + } self.req.parseExtensionArgs(args) - self.assertEqual(['http://foo','http://bar'], self.req.auth_policies) + self.assertEqual(['http://foo', 'http://bar'], self.req.auth_policies) self.assertEqual(None, self.req.auth_time) self.assertEqual(None, self.req.nist_auth_level) def test_fromSuccessResponse(self): openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) + 'mode': + 'id_res', + 'ns': + OPENID2_NS, + 'ns.pape': + pape.ns_uri, + 'pape.auth_policies': + ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), + 'pape.auth_time': + '1970-01-01T00:00:00Z' + }) signed_stuff = { - 'auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), - 'auth_time': '1970-01-01T00:00:00Z' + 'auth_policies': + ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), + 'auth_time': + '1970-01-01T00:00:00Z' } oid_req = DummySuccessResponse(openid_req_msg, signed_stuff) req = pape.Response.fromSuccessResponse(oid_req) - self.assertEqual([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], req.auth_policies) + self.assertEqual( + [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], + req.auth_policies) self.assertEqual('1970-01-01T00:00:00Z', req.auth_time) def test_fromSuccessResponseNoSignedArgs(self): openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) + 'mode': + 'id_res', + 'ns': + OPENID2_NS, + 'ns.pape': + pape.ns_uri, + 'pape.auth_policies': + ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), + 'pape.auth_time': + '1970-01-01T00:00:00Z' + }) signed_stuff = {} diff --git a/openid/test/test_pape_draft5.py b/openid/test/test_pape_draft5.py index a4026e8fe206d8497a2b12ce3d5f5b5a30ba225c..0d3905c23eb526cd4fee0226a1138ba09fddd082 100644 --- a/openid/test/test_pape_draft5.py +++ b/openid/test/test_pape_draft5.py @@ -1,11 +1,10 @@ - from openid.extensions.draft import pape5 as pape from openid.message import * from openid.server import server import warnings -warnings.filterwarnings('ignore', module=__name__, - message='"none" used as a policy URI') +warnings.filterwarnings( + 'ignore', module=__name__, message='"none" used as a policy URI') import unittest @@ -21,35 +20,33 @@ class PapeRequestTestCase(unittest.TestCase): self.assertFalse(self.req.preferred_auth_level_types) bogus_levels = ['http://janrain.com/our_levels'] - req2 = pape.Request( - [pape.AUTH_MULTI_FACTOR], 1000, bogus_levels) + req2 = pape.Request([pape.AUTH_MULTI_FACTOR], 1000, bogus_levels) self.assertEqual([pape.AUTH_MULTI_FACTOR], - req2.preferred_auth_policies) + req2.preferred_auth_policies) self.assertEqual(1000, req2.max_auth_age) self.assertEqual(bogus_levels, req2.preferred_auth_level_types) def test_addAuthLevel(self): self.req.addAuthLevel('http://example.com/', 'example') self.assertEqual(['http://example.com/'], - self.req.preferred_auth_level_types) + self.req.preferred_auth_level_types) self.assertEqual('http://example.com/', - self.req.auth_level_aliases['example']) + self.req.auth_level_aliases['example']) self.req.addAuthLevel('http://example.com/1', 'example1') self.assertEqual(['http://example.com/', 'http://example.com/1'], - self.req.preferred_auth_level_types) + self.req.preferred_auth_level_types) self.req.addAuthLevel('http://example.com/', 'exmpl') self.assertEqual(['http://example.com/', 'http://example.com/1'], - self.req.preferred_auth_level_types) + self.req.preferred_auth_level_types) self.req.addAuthLevel('http://example.com/', 'example') self.assertEqual(['http://example.com/', 'http://example.com/1'], - self.req.preferred_auth_level_types) + self.req.preferred_auth_level_types) - self.assertRaises(KeyError, - self.req.addAuthLevel, - 'http://example.com/2', 'example') + self.assertRaises(KeyError, self.req.addAuthLevel, + 'http://example.com/2', 'example') # alias is None; we expect a new one to be generated. uri = 'http://another.example.com/' @@ -67,35 +64,36 @@ class PapeRequestTestCase(unittest.TestCase): self.assertEqual([], self.req.preferred_auth_policies) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) self.assertEqual([pape.AUTH_MULTI_FACTOR], - self.req.preferred_auth_policies) + self.req.preferred_auth_policies) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) self.assertEqual([pape.AUTH_MULTI_FACTOR], - self.req.preferred_auth_policies) + self.req.preferred_auth_policies) self.req.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) - self.assertEqual([pape.AUTH_MULTI_FACTOR, - pape.AUTH_PHISHING_RESISTANT], - self.req.preferred_auth_policies) + self.assertEqual( + [pape.AUTH_MULTI_FACTOR, + pape.AUTH_PHISHING_RESISTANT], self.req.preferred_auth_policies) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.assertEqual([pape.AUTH_MULTI_FACTOR, - pape.AUTH_PHISHING_RESISTANT], - self.req.preferred_auth_policies) + self.assertEqual( + [pape.AUTH_MULTI_FACTOR, + pape.AUTH_PHISHING_RESISTANT], self.req.preferred_auth_policies) def test_getExtensionArgs(self): - self.assertEqual({'preferred_auth_policies': ''}, - self.req.getExtensionArgs()) + self.assertEqual({ + 'preferred_auth_policies': '' + }, self.req.getExtensionArgs()) self.req.addPolicyURI('http://uri') - self.assertEqual( - {'preferred_auth_policies': 'http://uri'}, - self.req.getExtensionArgs()) + self.assertEqual({ + 'preferred_auth_policies': 'http://uri' + }, self.req.getExtensionArgs()) self.req.addPolicyURI('http://zig') - self.assertEqual( - {'preferred_auth_policies': 'http://uri http://zig'}, - self.req.getExtensionArgs()) + self.assertEqual({ + 'preferred_auth_policies': 'http://uri http://zig' + }, self.req.getExtensionArgs()) self.req.max_auth_age = 789 - self.assertEqual( - {'preferred_auth_policies': 'http://uri http://zig', - 'max_auth_age': '789'}, - self.req.getExtensionArgs()) + self.assertEqual({ + 'preferred_auth_policies': 'http://uri http://zig', + 'max_auth_age': '789' + }, self.req.getExtensionArgs()) def test_getExtensionArgsWithAuthLevels(self): uri = 'http://example.com/auth_level' @@ -111,7 +109,7 @@ class PapeRequestTestCase(unittest.TestCase): ('auth_level.ns.%s' % alias2): uri2, 'preferred_auth_level_types': ' '.join([alias, alias2]), 'preferred_auth_policies': '', - } + } self.assertEqual(expected_args, self.req.getExtensionArgs()) @@ -127,10 +125,11 @@ class PapeRequestTestCase(unittest.TestCase): ('auth_level.ns.%s' % alias2): uri2, 'preferred_auth_level_types': ' '.join([alias, alias2]), 'preferred_auth_policies': '', - } + } # Check request object state - self.req.parseExtensionArgs(request_args, is_openid1=False, strict=False) + self.req.parseExtensionArgs( + request_args, is_openid1=False, strict=False) expected_auth_levels = [uri, uri2] @@ -141,8 +140,8 @@ class PapeRequestTestCase(unittest.TestCase): def test_parseExtensionArgsWithAuthLevels_openID1(self): request_args = { - 'preferred_auth_level_types':'nist jisa', - } + 'preferred_auth_level_types': 'nist jisa', + } expected_auth_levels = [pape.LEVELS_NIST, pape.LEVELS_JISA] self.req.parseExtensionArgs(request_args, is_openid1=True) self.assertEqual(expected_auth_levels, @@ -150,37 +149,49 @@ class PapeRequestTestCase(unittest.TestCase): self.req = pape.Request() self.req.parseExtensionArgs(request_args, is_openid1=False) - self.assertEqual([], - self.req.preferred_auth_level_types) + self.assertEqual([], self.req.preferred_auth_level_types) self.req = pape.Request() - self.assertRaises(ValueError, - self.req.parseExtensionArgs, - request_args, is_openid1=False, strict=True) + self.assertRaises( + ValueError, + self.req.parseExtensionArgs, + request_args, + is_openid1=False, + strict=True) def test_parseExtensionArgs_ignoreBadAuthLevels(self): - request_args = {'preferred_auth_level_types':'monkeys'} + request_args = {'preferred_auth_level_types': 'monkeys'} self.req.parseExtensionArgs(request_args, False) self.assertEqual([], self.req.preferred_auth_level_types) def test_parseExtensionArgs_strictBadAuthLevels(self): - request_args = {'preferred_auth_level_types':'monkeys'} - self.assertRaises(ValueError, self.req.parseExtensionArgs, - request_args, is_openid1=False, strict=True) + request_args = {'preferred_auth_level_types': 'monkeys'} + self.assertRaises( + ValueError, + self.req.parseExtensionArgs, + request_args, + is_openid1=False, + strict=True) def test_parseExtensionArgs(self): - args = {'preferred_auth_policies': 'http://foo http://bar', - 'max_auth_age': '9'} + args = { + 'preferred_auth_policies': 'http://foo http://bar', + 'max_auth_age': '9' + } self.req.parseExtensionArgs(args, False) self.assertEqual(9, self.req.max_auth_age) - self.assertEqual(['http://foo','http://bar'], - self.req.preferred_auth_policies) + self.assertEqual(['http://foo', 'http://bar'], + self.req.preferred_auth_policies) self.assertEqual([], self.req.preferred_auth_level_types) def test_parseExtensionArgs_strict_bad_auth_age(self): args = {'max_auth_age': 'not an int'} - self.assertRaises(ValueError, self.req.parseExtensionArgs, args, - is_openid1=False, strict=True) + self.assertRaises( + ValueError, + self.req.parseExtensionArgs, + args, + is_openid1=False, + strict=True) def test_parseExtensionArgs_empty(self): self.req.parseExtensionArgs({}, False) @@ -191,12 +202,17 @@ class PapeRequestTestCase(unittest.TestCase): def test_fromOpenIDRequest(self): policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'checkid_setup', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.preferred_auth_policies': ' '.join(policy_uris), - 'pape.max_auth_age': '5476' - }) + 'mode': + 'checkid_setup', + 'ns': + OPENID2_NS, + 'ns.pape': + pape.ns_uri, + 'pape.preferred_auth_policies': + ' '.join(policy_uris), + 'pape.max_auth_age': + '5476' + }) oid_req = server.OpenIDRequest() oid_req.message = openid_req_msg req = pape.Request.fromOpenIDRequest(oid_req) @@ -208,15 +224,16 @@ class PapeRequestTestCase(unittest.TestCase): openid_req = server.OpenIDRequest() openid_req.message = message pape_req = pape.Request.fromOpenIDRequest(openid_req) - assert(pape_req is None) + assert (pape_req is None) def test_preferred_types(self): self.req.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - pt = self.req.preferredTypes([pape.AUTH_MULTI_FACTOR, - pape.AUTH_MULTI_FACTOR_PHYSICAL]) + pt = self.req.preferredTypes( + [pape.AUTH_MULTI_FACTOR, pape.AUTH_MULTI_FACTOR_PHYSICAL]) self.assertEqual([pape.AUTH_MULTI_FACTOR], pt) + class DummySuccessResponse: def __init__(self, message, signed_stuff): self.message = message @@ -228,6 +245,7 @@ class DummySuccessResponse: def getSignedNS(self, ns_uri): return self.signed_stuff + class PapeResponseTestCase(unittest.TestCase): def setUp(self): self.resp = pape.Response() @@ -238,8 +256,8 @@ class PapeResponseTestCase(unittest.TestCase): self.assertEqual('pape', self.resp.ns_alias) self.assertEqual(None, self.resp.nist_auth_level) - req2 = pape.Response([pape.AUTH_MULTI_FACTOR], - "2004-12-11T10:30:44Z", {pape.LEVELS_NIST: 3}) + req2 = pape.Response([pape.AUTH_MULTI_FACTOR], "2004-12-11T10:30:44Z", + {pape.LEVELS_NIST: 3}) self.assertEqual([pape.AUTH_MULTI_FACTOR], req2.auth_policies) self.assertEqual("2004-12-11T10:30:44Z", req2.auth_time) self.assertEqual(3, req2.nist_auth_level) @@ -251,50 +269,53 @@ class PapeResponseTestCase(unittest.TestCase): self.resp.addPolicyURI(pape.AUTH_MULTI_FACTOR) self.assertEqual([pape.AUTH_MULTI_FACTOR], self.resp.auth_policies) self.resp.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) - self.assertEqual([pape.AUTH_MULTI_FACTOR, - pape.AUTH_PHISHING_RESISTANT], - self.resp.auth_policies) + self.assertEqual( + [pape.AUTH_MULTI_FACTOR, + pape.AUTH_PHISHING_RESISTANT], self.resp.auth_policies) self.resp.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.assertEqual([pape.AUTH_MULTI_FACTOR, - pape.AUTH_PHISHING_RESISTANT], - self.resp.auth_policies) + self.assertEqual( + [pape.AUTH_MULTI_FACTOR, + pape.AUTH_PHISHING_RESISTANT], self.resp.auth_policies) - self.assertRaises(RuntimeError, self.resp.addPolicyURI, - pape.AUTH_NONE) + self.assertRaises(RuntimeError, self.resp.addPolicyURI, pape.AUTH_NONE) def test_getExtensionArgs(self): - self.assertEqual({'auth_policies': pape.AUTH_NONE}, - self.resp.getExtensionArgs()) + self.assertEqual({ + 'auth_policies': pape.AUTH_NONE + }, self.resp.getExtensionArgs()) self.resp.addPolicyURI('http://uri') - self.assertEqual({'auth_policies': 'http://uri'}, - self.resp.getExtensionArgs()) + self.assertEqual({ + 'auth_policies': 'http://uri' + }, self.resp.getExtensionArgs()) self.resp.addPolicyURI('http://zig') - self.assertEqual({'auth_policies': 'http://uri http://zig'}, - self.resp.getExtensionArgs()) + self.assertEqual({ + 'auth_policies': 'http://uri http://zig' + }, self.resp.getExtensionArgs()) self.resp.auth_time = "1776-07-04T14:43:12Z" - self.assertEqual( - {'auth_policies': 'http://uri http://zig', - 'auth_time': "1776-07-04T14:43:12Z"}, - self.resp.getExtensionArgs()) + self.assertEqual({ + 'auth_policies': 'http://uri http://zig', + 'auth_time': "1776-07-04T14:43:12Z" + }, self.resp.getExtensionArgs()) self.resp.setAuthLevel(pape.LEVELS_NIST, '3') - self.assertEqual( - {'auth_policies': 'http://uri http://zig', - 'auth_time': "1776-07-04T14:43:12Z", - 'auth_level.nist': '3', - 'auth_level.ns.nist': pape.LEVELS_NIST}, - self.resp.getExtensionArgs()) + self.assertEqual({ + 'auth_policies': 'http://uri http://zig', + 'auth_time': "1776-07-04T14:43:12Z", + 'auth_level.nist': '3', + 'auth_level.ns.nist': pape.LEVELS_NIST + }, self.resp.getExtensionArgs()) def test_getExtensionArgs_error_auth_age(self): self.resp.auth_time = "long ago" self.assertRaises(ValueError, self.resp.getExtensionArgs) def test_parseExtensionArgs(self): - args = {'auth_policies': 'http://foo http://bar', - 'auth_time': '1970-01-01T00:00:00Z'} + args = { + 'auth_policies': 'http://foo http://bar', + 'auth_time': '1970-01-01T00:00:00Z' + } self.resp.parseExtensionArgs(args, is_openid1=False) self.assertEqual('1970-01-01T00:00:00Z', self.resp.auth_time) - self.assertEqual(['http://foo','http://bar'], - self.resp.auth_policies) + self.assertEqual(['http://foo', 'http://bar'], self.resp.auth_policies) def test_parseExtensionArgs_valid_none(self): args = {'auth_policies': pape.AUTH_NONE} @@ -310,7 +331,10 @@ class PapeResponseTestCase(unittest.TestCase): args = {'auth_policies': 'none'} self.assertRaises( ValueError, - self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) + self.resp.parseExtensionArgs, + args, + is_openid1=False, + strict=True) def test_parseExtensionArgs_empty(self): self.resp.parseExtensionArgs({}, is_openid1=False) @@ -320,14 +344,16 @@ class PapeResponseTestCase(unittest.TestCase): def test_parseExtensionArgs_empty_strict(self): self.assertRaises( ValueError, - self.resp.parseExtensionArgs, {}, is_openid1=False, strict=True) + self.resp.parseExtensionArgs, {}, + is_openid1=False, + strict=True) def test_parseExtensionArgs_ignore_superfluous_none(self): policies = [pape.AUTH_NONE, pape.AUTH_MULTI_FACTOR_PHYSICAL] args = { 'auth_policies': ' '.join(policies), - } + } self.resp.parseExtensionArgs(args, is_openid1=False, strict=False) @@ -339,21 +365,32 @@ class PapeResponseTestCase(unittest.TestCase): args = { 'auth_policies': ' '.join(policies), - } + } - self.assertRaises(ValueError, self.resp.parseExtensionArgs, - args, is_openid1=False, strict=True) + self.assertRaises( + ValueError, + self.resp.parseExtensionArgs, + args, + is_openid1=False, + strict=True) def test_parseExtensionArgs_strict_bogus1(self): - args = {'auth_policies': 'http://foo http://bar', - 'auth_time': 'yesterday'} - self.assertRaises(ValueError, self.resp.parseExtensionArgs, - args, is_openid1=False, strict=True) + args = { + 'auth_policies': 'http://foo http://bar', + 'auth_time': 'yesterday' + } + self.assertRaises( + ValueError, + self.resp.parseExtensionArgs, + args, + is_openid1=False, + strict=True) def test_parseExtensionArgs_openid1_strict(self): - args = {'auth_level.nist': '0', - 'auth_policies': pape.AUTH_NONE, - } + args = { + 'auth_level.nist': '0', + 'auth_policies': pape.AUTH_NONE, + } self.resp.parseExtensionArgs(args, strict=True, is_openid1=True) self.assertEqual('0', self.resp.getAuthLevel(pape.LEVELS_NIST)) self.assertEqual([], self.resp.auth_policies) @@ -361,57 +398,69 @@ class PapeResponseTestCase(unittest.TestCase): def test_parseExtensionArgs_strict_no_namespace_decl_openid2(self): # Test the case where the namespace is not declared for an # auth level. - args = {'auth_policies': pape.AUTH_NONE, - 'auth_level.nist': '0', - } - self.assertRaises(ValueError, self.resp.parseExtensionArgs, - args, is_openid1=False, strict=True) + args = { + 'auth_policies': pape.AUTH_NONE, + 'auth_level.nist': '0', + } + self.assertRaises( + ValueError, + self.resp.parseExtensionArgs, + args, + is_openid1=False, + strict=True) def test_parseExtensionArgs_nostrict_no_namespace_decl_openid2(self): # Test the case where the namespace is not declared for an # auth level. - args = {'auth_policies': pape.AUTH_NONE, - 'auth_level.nist': '0', - } + args = { + 'auth_policies': pape.AUTH_NONE, + 'auth_level.nist': '0', + } self.resp.parseExtensionArgs(args, is_openid1=False, strict=False) # There is no namespace declaration for this auth level. - self.assertRaises(KeyError, self.resp.getAuthLevel, - pape.LEVELS_NIST) + self.assertRaises(KeyError, self.resp.getAuthLevel, pape.LEVELS_NIST) def test_parseExtensionArgs_strict_good(self): - args = {'auth_policies': 'http://foo http://bar', - 'auth_time': '1970-01-01T00:00:00Z', - 'auth_level.nist': '0', - 'auth_level.ns.nist': pape.LEVELS_NIST} + args = { + 'auth_policies': 'http://foo http://bar', + 'auth_time': '1970-01-01T00:00:00Z', + 'auth_level.nist': '0', + 'auth_level.ns.nist': pape.LEVELS_NIST + } self.resp.parseExtensionArgs(args, is_openid1=False, strict=True) - self.assertEqual(['http://foo','http://bar'], - self.resp.auth_policies) + self.assertEqual(['http://foo', 'http://bar'], self.resp.auth_policies) self.assertEqual('1970-01-01T00:00:00Z', self.resp.auth_time) self.assertEqual(0, self.resp.nist_auth_level) def test_parseExtensionArgs_nostrict_bogus(self): - args = {'auth_policies': 'http://foo http://bar', - 'auth_time': 'when the cows come home', - 'nist_auth_level': 'some'} + args = { + 'auth_policies': 'http://foo http://bar', + 'auth_time': 'when the cows come home', + 'nist_auth_level': 'some' + } self.resp.parseExtensionArgs(args, is_openid1=False) - self.assertEqual(['http://foo','http://bar'], - self.resp.auth_policies) + self.assertEqual(['http://foo', 'http://bar'], self.resp.auth_policies) self.assertEqual(None, self.resp.auth_time) self.assertEqual(None, self.resp.nist_auth_level) def test_fromSuccessResponse(self): policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join(policy_uris), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) + 'mode': + 'id_res', + 'ns': + OPENID2_NS, + 'ns.pape': + pape.ns_uri, + 'pape.auth_policies': + ' '.join(policy_uris), + 'pape.auth_time': + '1970-01-01T00:00:00Z' + }) signed_stuff = { - 'auth_policies': ' '.join(policy_uris), - 'auth_time': '1970-01-01T00:00:00Z' + 'auth_policies': ' '.join(policy_uris), + 'auth_time': '1970-01-01T00:00:00Z' } oid_req = DummySuccessResponse(openid_req_msg, signed_stuff) req = pape.Response.fromSuccessResponse(oid_req) @@ -421,12 +470,17 @@ class PapeResponseTestCase(unittest.TestCase): def test_fromSuccessResponseNoSignedArgs(self): policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join(policy_uris), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) + 'mode': + 'id_res', + 'ns': + OPENID2_NS, + 'ns.pape': + pape.ns_uri, + 'pape.auth_policies': + ' '.join(policy_uris), + 'pape.auth_time': + '1970-01-01T00:00:00Z' + }) signed_stuff = {} @@ -438,5 +492,6 @@ class PapeResponseTestCase(unittest.TestCase): resp = pape.Response.fromSuccessResponse(oid_req) self.assertTrue(resp is None) + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_parsehtml.py b/openid/test/test_parsehtml.py index 387ddc55a0a2374627ebc18f6d9d68f766f4094c..fdeecd04356d2d5f2ee94c1bff1e6850b1397b37 100644 --- a/openid/test/test_parsehtml.py +++ b/openid/test/test_parsehtml.py @@ -36,10 +36,8 @@ class _TestCase(unittest.TestCase): self.assertTrue(self.expected == 'EOF', (self.case, self.expected)) def shortDescription(self): - return "%s (%s<%s>)" % ( - self.testname, - self.__class__.__module__, - os.path.basename(self.filename)) + return "%s (%s<%s>)" % (self.testname, self.__class__.__module__, + os.path.basename(self.filename)) def parseCases(data): @@ -62,6 +60,7 @@ def test(): runner = unittest.TextTestRunner() return runner.run(pyUnitTests()) + filenames = ['data/test1-parsehtml.txt'] default_test_files = [] diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index bd877ab2d60e939c43bd79ef5d896ae8968ba871..a8b089039f9bcc06c324d8a7ffd9ab8a58bce363 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -9,6 +9,7 @@ from openid.server import trustroot from openid.test.support import CatchLogs import unittest + # Too many methods does not apply to unit test objects #pylint:disable-msg=R0904 class TestBuildDiscoveryURL(unittest.TestCase): @@ -36,6 +37,7 @@ class TestBuildDiscoveryURL(unittest.TestCase): self.failUnlessDiscoURL('http://*.example.com/foo', 'http://www.example.com/foo') + class TestExtractReturnToURLs(unittest.TestCase): disco_url = 'http://example.com/' @@ -54,20 +56,20 @@ class TestExtractReturnToURLs(unittest.TestCase): return result def failUnlessFileHasReturnURLs(self, filename, expected_return_urls): - self.failUnlessXRDSHasReturnURLs(open(filename).read(), - expected_return_urls) + self.failUnlessXRDSHasReturnURLs( + open(filename).read(), expected_return_urls) def failUnlessXRDSHasReturnURLs(self, data, expected_return_urls): self.data = data - actual_return_urls = list(trustroot.getAllowedReturnURLs( - self.disco_url)) + actual_return_urls = list( + trustroot.getAllowedReturnURLs(self.disco_url)) self.assertEqual(expected_return_urls, actual_return_urls) def failUnlessDiscoveryFailure(self, text): self.data = text - self.assertRaises( - DiscoveryFailure, trustroot.getAllowedReturnURLs, self.disco_url) + self.assertRaises(DiscoveryFailure, trustroot.getAllowedReturnURLs, + self.disco_url) def test_empty(self): self.failUnlessDiscoveryFailure('') @@ -140,8 +142,7 @@ class TestExtractReturnToURLs(unittest.TestCase): </Service> </XRD> </xrds:XRDS> -''', ['http://rp.example.com/return', - 'http://other.rp.example.com/return']) +''', ['http://rp.example.com/return', 'http://other.rp.example.com/return']) def test_twoEntries_withOther(self): self.failUnlessXRDSHasReturnURLs(b'''\ @@ -164,9 +165,7 @@ class TestExtractReturnToURLs(unittest.TestCase): </Service> </XRD> </xrds:XRDS> -''', ['http://rp.example.com/return', - 'http://other.rp.example.com/return']) - +''', ['http://rp.example.com/return', 'http://other.rp.example.com/return']) class TestReturnToMatches(unittest.TestCase): @@ -179,31 +178,30 @@ class TestReturnToMatches(unittest.TestCase): def test_garbageMatch(self): r = 'http://example.com/return.to' - self.assertTrue(trustroot.returnToMatches( - ['This is not a URL at all. In fact, it has characters, ' - 'like "<" that are not allowed in URLs', - r], - r)) + self.assertTrue( + trustroot.returnToMatches([ + 'This is not a URL at all. In fact, it has characters, ' + 'like "<" that are not allowed in URLs', r + ], r)) def test_descendant(self): r = 'http://example.com/return.to' - self.assertTrue(trustroot.returnToMatches( - [r], - 'http://example.com/return.to/user:joe')) + self.assertTrue( + trustroot.returnToMatches([r], + 'http://example.com/return.to/user:joe')) def test_wildcard(self): - self.assertFalse(trustroot.returnToMatches( - ['http://*.example.com/return.to'], - 'http://example.com/return.to')) + self.assertFalse( + trustroot.returnToMatches(['http://*.example.com/return.to'], + 'http://example.com/return.to')) def test_noMatch(self): r = 'http://example.com/return.to' - self.assertFalse(trustroot.returnToMatches( - [r], - 'http://example.com/xss_exploit')) + self.assertFalse( + trustroot.returnToMatches([r], 'http://example.com/xss_exploit')) -class TestVerifyReturnTo(unittest.TestCase, CatchLogs): +class TestVerifyReturnTo(unittest.TestCase, CatchLogs): def setUp(self): CatchLogs.setUp(self) @@ -221,8 +219,7 @@ class TestVerifyReturnTo(unittest.TestCase, CatchLogs): self.assertEqual('http://www.example.com/', disco_url) return [return_to] - self.assertTrue( - trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) + self.assertTrue(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) self.failUnlessLogEmpty() def test_verifyFailWithDiscoveryCalled(self): @@ -249,5 +246,6 @@ class TestVerifyReturnTo(unittest.TestCase, CatchLogs): trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) self.failUnlessLogMatches("Attempting to verify") + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 0f670ac7d909c86007afc291cd1e19839cdf83af..21a7c6f7eba8d56e86a9bcf9231a864c1f1f6ff6 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -27,16 +27,19 @@ class TestProtocolError(unittest.TestCase): return_to = "http://rp.unittest/consumer" # will be a ProtocolError raised by Decode or CheckIDRequest.answer args = Message.fromPostArgs({ - 'openid.mode': 'monkeydance', - 'openid.identity': 'http://wagu.unittest/', - 'openid.return_to': return_to, - }) + 'openid.mode': + 'monkeydance', + 'openid.identity': + 'http://wagu.unittest/', + 'openid.return_to': + return_to, + }) e = server.ProtocolError(args, "plucky") self.assertTrue(e.hasReturnTo()) expected_args = { 'openid.mode': ['error'], 'openid.error': ['plucky'], - } + } rt_base, result_args = e.encodeToURL().split('?', 1) result_args = parse_qs(result_args) @@ -46,19 +49,24 @@ class TestProtocolError(unittest.TestCase): return_to = "http://rp.unittest/consumer" # will be a ProtocolError raised by Decode or CheckIDRequest.answer args = Message.fromPostArgs({ - 'openid.ns': OPENID2_NS, - 'openid.mode': 'monkeydance', - 'openid.identity': 'http://wagu.unittest/', - 'openid.claimed_id': 'http://wagu.unittest/', - 'openid.return_to': return_to, - }) + 'openid.ns': + OPENID2_NS, + 'openid.mode': + 'monkeydance', + 'openid.identity': + 'http://wagu.unittest/', + 'openid.claimed_id': + 'http://wagu.unittest/', + 'openid.return_to': + return_to, + }) e = server.ProtocolError(args, "plucky") self.assertTrue(e.hasReturnTo()) expected_args = { 'openid.ns': [OPENID2_NS], 'openid.mode': ['error'], 'openid.error': ['plucky'], - } + } rt_base, result_args = e.encodeToURL().split('?', 1) result_args = parse_qs(result_args) @@ -68,19 +76,24 @@ class TestProtocolError(unittest.TestCase): return_to = "http://rp.unittest/consumer" + ('x' * OPENID1_URL_LIMIT) # will be a ProtocolError raised by Decode or CheckIDRequest.answer args = Message.fromPostArgs({ - 'openid.ns': OPENID2_NS, - 'openid.mode': 'monkeydance', - 'openid.identity': 'http://wagu.unittest/', - 'openid.claimed_id': 'http://wagu.unittest/', - 'openid.return_to': return_to, - }) + 'openid.ns': + OPENID2_NS, + 'openid.mode': + 'monkeydance', + 'openid.identity': + 'http://wagu.unittest/', + 'openid.claimed_id': + 'http://wagu.unittest/', + 'openid.return_to': + return_to, + }) e = server.ProtocolError(args, "plucky") self.assertTrue(e.hasReturnTo()) expected_args = { 'openid.ns': [OPENID2_NS], 'openid.mode': ['error'], 'openid.error': ['plucky'], - } + } self.assertTrue(e.whichEncoding() == server.ENCODE_HTML_FORM) self.assertTrue(e.toFormMarkup() == e.toMessage().toFormMarkup( @@ -90,16 +103,19 @@ class TestProtocolError(unittest.TestCase): return_to = "http://rp.unittest/consumer" + ('x' * OPENID1_URL_LIMIT) # will be a ProtocolError raised by Decode or CheckIDRequest.answer args = Message.fromPostArgs({ - 'openid.mode': 'monkeydance', - 'openid.identity': 'http://wagu.unittest/', - 'openid.return_to': return_to, - }) + 'openid.mode': + 'monkeydance', + 'openid.identity': + 'http://wagu.unittest/', + 'openid.return_to': + return_to, + }) e = server.ProtocolError(args, "plucky") self.assertTrue(e.hasReturnTo()) expected_args = { 'openid.mode': ['error'], 'openid.error': ['plucky'], - } + } self.assertTrue(e.whichEncoding() == server.ENCODE_URL) @@ -112,7 +128,7 @@ class TestProtocolError(unittest.TestCase): args = Message.fromPostArgs({ 'openid.mode': 'zebradance', 'openid.identity': 'http://wagu.unittest/', - }) + }) e = server.ProtocolError(args, "waffles") self.assertFalse(e.hasReturnTo()) expected = b"""error:waffles @@ -148,14 +164,14 @@ class TestDecode(unittest.TestCase): args = { 'pony': 'spotted', 'sreg.mutant_power': 'decaffinator', - } + } self.assertRaises(server.ProtocolError, self.decode, args) def test_bad(self): args = { 'openid.mode': 'twos-compliment', 'openid.pants': 'zippered', - } + } self.assertRaises(server.ProtocolError, self.decode, args) def test_dictOfLists(self): @@ -165,13 +181,13 @@ class TestDecode(unittest.TestCase): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.trust_root': self.tr_url, - } + } try: result = self.decode(args) except TypeError as err: self.assertTrue(str(err).find('values') != -1, err) else: - self.fail("Expected TypeError, but got result %s" % (result,)) + self.fail("Expected TypeError, but got result %s" % (result, )) def test_checkidImmediate(self): args = { @@ -182,7 +198,7 @@ class TestDecode(unittest.TestCase): 'openid.trust_root': self.tr_url, # should be ignored 'openid.some.extension': 'junk', - } + } r = self.decode(args) self.assertTrue(isinstance(r, server.CheckIDRequest)) self.assertEqual(r.mode, "checkid_immediate") @@ -199,7 +215,7 @@ class TestDecode(unittest.TestCase): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.trust_root': self.tr_url, - } + } r = self.decode(args) self.assertTrue(isinstance(r, server.CheckIDRequest)) self.assertEqual(r.mode, "checkid_setup") @@ -217,7 +233,7 @@ class TestDecode(unittest.TestCase): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, - } + } r = self.decode(args) self.assertTrue(isinstance(r, server.CheckIDRequest)) self.assertEqual(r.mode, "checkid_setup") @@ -235,7 +251,7 @@ class TestDecode(unittest.TestCase): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, - } + } self.assertRaises(server.ProtocolError, self.decode, args) def test_checkidSetupNoIdentityOpenID2(self): @@ -245,7 +261,7 @@ class TestDecode(unittest.TestCase): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, - } + } r = self.decode(args) self.assertTrue(isinstance(r, server.CheckIDRequest)) self.assertEqual(r.mode, "checkid_setup") @@ -263,7 +279,7 @@ class TestDecode(unittest.TestCase): 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.trust_root': self.tr_url, - } + } self.assertRaises(server.ProtocolError, self.decode, args) def test_checkidSetupNoReturnOpenID2(self): @@ -278,7 +294,7 @@ class TestDecode(unittest.TestCase): 'openid.claimed_id': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.realm': self.tr_url, - } + } self.assertTrue(isinstance(self.decode(args), server.CheckIDRequest)) req = self.decode(args) @@ -296,7 +312,7 @@ class TestDecode(unittest.TestCase): 'openid.mode': 'checkid_setup', 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, - } + } self.assertRaises(server.ProtocolError, self.decode, args) def test_checkidSetupBadReturn(self): @@ -305,14 +321,14 @@ class TestDecode(unittest.TestCase): 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': 'not a url', - } + } try: result = self.decode(args) except server.ProtocolError as err: self.assertTrue(err.openid_message) else: self.fail("Expected ProtocolError, instead returned with %s" % - (result,)) + (result, )) def test_checkidSetupUntrustedReturn(self): args = { @@ -321,14 +337,14 @@ class TestDecode(unittest.TestCase): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.trust_root': 'http://not-the-return-place.unittest/', - } + } try: result = self.decode(args) except server.UntrustedReturnURL as err: self.assertTrue(err.openid_message) else: self.fail("Expected UntrustedReturnURL, instead returned with %s" % - (result,)) + (result, )) def test_checkAuth(self): args = { @@ -340,7 +356,7 @@ class TestDecode(unittest.TestCase): 'openid.return_to': 'signedval2', 'openid.response_nonce': 'signedval3', 'openid.baz': 'unsigned', - } + } r = self.decode(args) self.assertTrue(isinstance(r, server.CheckAuthRequest)) self.assertEqual(r.mode, 'check_authentication') @@ -354,7 +370,7 @@ class TestDecode(unittest.TestCase): 'openid.foo': 'signedval1', 'openid.bar': 'signedval2', 'openid.baz': 'unsigned', - } + } self.assertRaises(server.ProtocolError, self.decode, args) def test_checkAuthAndInvalidate(self): @@ -368,7 +384,7 @@ class TestDecode(unittest.TestCase): 'openid.return_to': 'signedval2', 'openid.response_nonce': 'signedval3', 'openid.baz': 'unsigned', - } + } r = self.decode(args) self.assertTrue(isinstance(r, server.CheckAuthRequest)) self.assertEqual(r.invalidate_handle, '[[SMART_handle]]') @@ -378,7 +394,7 @@ class TestDecode(unittest.TestCase): 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "Rzup9265tw==", - } + } r = self.decode(args) self.assertTrue(isinstance(r, server.AssociateRequest)) self.assertEqual(r.mode, "associate") @@ -391,7 +407,7 @@ class TestDecode(unittest.TestCase): args = { 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', - } + } # Using DH-SHA1 without supplying dh_consumer_public is an error. self.assertRaises(server.ProtocolError, self.decode, args) @@ -400,7 +416,7 @@ class TestDecode(unittest.TestCase): 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "donkeydonkeydonkey", - } + } self.assertRaises(server.ProtocolError, self.decode, args) def test_associateDHModGen(self): @@ -411,7 +427,7 @@ class TestDecode(unittest.TestCase): 'openid.dh_consumer_public': "Rzup9265tw==", 'openid.dh_modulus': cryptutil.longToBase64(ALT_MODULUS), 'openid.dh_gen': cryptutil.longToBase64(ALT_GEN) - } + } r = self.decode(args) self.assertTrue(isinstance(r, server.AssociateRequest)) self.assertEqual(r.mode, "associate") @@ -429,7 +445,7 @@ class TestDecode(unittest.TestCase): 'openid.dh_consumer_public': "Rzup9265tw==", 'openid.dh_modulus': 'pizza', 'openid.dh_gen': 'gnocchi', - } + } self.assertRaises(server.ProtocolError, self.decode, args) def test_associateDHMissingModGen(self): @@ -439,7 +455,7 @@ class TestDecode(unittest.TestCase): 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "Rzup9265tw==", 'openid.dh_modulus': 'pizza', - } + } self.assertRaises(server.ProtocolError, self.decode, args) # def test_associateDHInvalidModGen(self): @@ -460,13 +476,13 @@ class TestDecode(unittest.TestCase): 'openid.mode': 'associate', 'openid.session_type': 'FLCL6', 'openid.dh_consumer_public': "YQ==\n", - } + } self.assertRaises(server.ProtocolError, self.decode, args) def test_associatePlain(self): args = { 'openid.mode': 'associate', - } + } r = self.decode(args) self.assertTrue(isinstance(r, server.AssociateRequest)) self.assertEqual(r.mode, "associate") @@ -477,12 +493,11 @@ class TestDecode(unittest.TestCase): args = { 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "my public keeey", - } + } self.assertRaises(server.ProtocolError, self.decode, args) def test_invalidns(self): - args = {'openid.ns': 'Tuesday', - 'openid.mode': 'associate'} + args = {'openid.ns': 'Tuesday', 'openid.mode': 'associate'} try: r = self.decode(args) @@ -493,11 +508,10 @@ class TestDecode(unittest.TestCase): # The error message contains the bad openid.ns. self.assertTrue('Tuesday' in str(err), str(err)) else: - self.fail("Expected ProtocolError but returned with %r" % (r,)) + self.fail("Expected ProtocolError but returned with %r" % (r, )) class TestEncode(unittest.TestCase): - def setUp(self): self.encoder = server.Encoder() self.encode = self.encoder.encode @@ -516,17 +530,21 @@ class TestEncode(unittest.TestCase): trust_root='http://burr.unittest/', return_to='http://burr.unittest/999', immediate=False, - op_endpoint=self.server.op_endpoint, - ) + op_endpoint=self.server.op_endpoint, ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ - 'ns': OPENID2_NS, - 'mode': 'id_res', - 'identity': request.identity, - 'claimed_id': request.identity, - 'return_to': request.return_to, - }) + 'ns': + OPENID2_NS, + 'mode': + 'id_res', + 'identity': + request.identity, + 'claimed_id': + request.identity, + 'return_to': + request.return_to, + }) self.assertFalse(response.renderAsForm()) self.assertTrue(response.whichEncoding() == server.ENCODE_URL) @@ -544,17 +562,21 @@ class TestEncode(unittest.TestCase): trust_root='http://burr.unittest/', return_to='http://burr.unittest/999', immediate=False, - op_endpoint=self.server.op_endpoint, - ) + op_endpoint=self.server.op_endpoint, ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ - 'ns': OPENID2_NS, - 'mode': 'id_res', - 'identity': request.identity, - 'claimed_id': request.identity, - 'return_to': 'x' * OPENID1_URL_LIMIT, - }) + 'ns': + OPENID2_NS, + 'mode': + 'id_res', + 'identity': + request.identity, + 'claimed_id': + request.identity, + 'return_to': + 'x' * OPENID1_URL_LIMIT, + }) self.assertTrue(response.renderAsForm()) self.assertTrue(len(response.encodeToURL()) > OPENID1_URL_LIMIT) @@ -568,17 +590,21 @@ class TestEncode(unittest.TestCase): trust_root='http://burr.unittest/', return_to='http://burr.unittest/999', immediate=False, - op_endpoint=self.server.op_endpoint, - ) + op_endpoint=self.server.op_endpoint, ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ - 'ns': OPENID2_NS, - 'mode': 'id_res', - 'identity': request.identity, - 'claimed_id': request.identity, - 'return_to': 'x' * OPENID1_URL_LIMIT, - }) + 'ns': + OPENID2_NS, + 'mode': + 'id_res', + 'identity': + request.identity, + 'claimed_id': + request.identity, + 'return_to': + 'x' * OPENID1_URL_LIMIT, + }) form_markup = response.toFormMarkup({'foo': 'bar'}) self.assertTrue(' foo="bar"' in form_markup) @@ -589,17 +615,21 @@ class TestEncode(unittest.TestCase): trust_root='http://burr.unittest/', return_to='http://burr.unittest/999', immediate=False, - op_endpoint=self.server.op_endpoint, - ) + op_endpoint=self.server.op_endpoint, ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ - 'ns': OPENID2_NS, - 'mode': 'id_res', - 'identity': request.identity, - 'claimed_id': request.identity, - 'return_to': 'x' * OPENID1_URL_LIMIT, - }) + 'ns': + OPENID2_NS, + 'mode': + 'id_res', + 'identity': + request.identity, + 'claimed_id': + request.identity, + 'return_to': + 'x' * OPENID1_URL_LIMIT, + }) html = response.toHTML() self.assertTrue('<html>' in html) self.assertTrue('</html>' in html) @@ -619,22 +649,24 @@ class TestEncode(unittest.TestCase): trust_root='http://burr.unittest/', return_to='http://burr.unittest/999', immediate=False, - op_endpoint=self.server.op_endpoint, - ) + op_endpoint=self.server.op_endpoint, ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'identity': request.identity, - 'return_to': 'x' * OPENID1_URL_LIMIT, - }) + 'mode': + 'id_res', + 'identity': + request.identity, + 'return_to': + 'x' * OPENID1_URL_LIMIT, + }) self.assertFalse(response.renderAsForm()) self.assertTrue(len(response.encodeToURL()) > OPENID1_URL_LIMIT) self.assertTrue(response.whichEncoding() == server.ENCODE_URL) webresponse = self.encode(response) self.assertEqual(webresponse.headers['location'], - response.encodeToURL()) + response.encodeToURL()) def test_id_res(self): request = server.CheckIDRequest( @@ -642,23 +674,25 @@ class TestEncode(unittest.TestCase): trust_root='http://burr.unittest/', return_to='http://burr.unittest/999', immediate=False, - op_endpoint=self.server.op_endpoint, - ) + op_endpoint=self.server.op_endpoint, ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'identity': request.identity, - 'return_to': request.return_to, - }) + 'mode': + 'id_res', + 'identity': + request.identity, + 'return_to': + request.return_to, + }) webresponse = self.encode(response) self.assertEqual(webresponse.code, server.HTTP_REDIRECT) self.assertTrue('location' in webresponse.headers) location = webresponse.headers['location'] - self.assertTrue(location.startswith(request.return_to), - "%s does not start with %s" % (location, - request.return_to)) + self.assertTrue( + location.startswith(request.return_to), + "%s does not start with %s" % (location, request.return_to)) # argh. q2 = dict(parse_qsl(urlparse(location)[4])) expected = response.fields.toPostArgs() @@ -670,13 +704,12 @@ class TestEncode(unittest.TestCase): trust_root='http://burr.unittest/', return_to='http://burr.unittest/999', immediate=False, - op_endpoint=self.server.op_endpoint, - ) + op_endpoint=self.server.op_endpoint, ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'mode': 'cancel', - }) + }) webresponse = self.encode(response) self.assertEqual(webresponse.code, server.HTTP_REDIRECT) self.assertTrue('location' in webresponse.headers) @@ -687,13 +720,12 @@ class TestEncode(unittest.TestCase): trust_root='http://burr.unittest/', return_to='http://burr.unittest/999', immediate=False, - op_endpoint=self.server.op_endpoint, - ) + op_endpoint=self.server.op_endpoint, ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'mode': 'cancel', - }) + }) form = response.toFormMarkup() self.assertTrue(form) @@ -703,7 +735,9 @@ class TestEncode(unittest.TestCase): request = server.AssociateRequest.fromMessage(msg) response = server.OpenIDResponse(request) response.fields = Message.fromPostArgs( - {'openid.assoc_handle': "every-zig"}) + { + 'openid.assoc_handle': "every-zig" + }) webresponse = self.encode(response) body = """assoc_handle:every-zig """ @@ -712,14 +746,14 @@ class TestEncode(unittest.TestCase): self.assertEqual(webresponse.body, body) def test_checkauthReply(self): - request = server.CheckAuthRequest('a_sock_monkey', - 'siggggg', - []) + request = server.CheckAuthRequest('a_sock_monkey', 'siggggg', []) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ - 'is_valid': 'true', - 'invalidate_handle': 'xXxX:xXXx' - }) + 'is_valid': + 'true', + 'invalidate_handle': + 'xXxX:xXXx' + }) body = """invalidate_handle:xXxX:xXXx is_valid:true """ @@ -731,7 +765,7 @@ is_valid:true def test_unencodableError(self): args = Message.fromPostArgs({ 'openid.identity': 'http://limu.unittest/', - }) + }) e = server.ProtocolError(args, "wet paint") self.assertRaises(server.EncodingError, self.encode, e) @@ -739,7 +773,7 @@ is_valid:true args = Message.fromPostArgs({ 'openid.mode': 'associate', 'openid.identity': 'http://limu.unittest/', - }) + }) body = "error:snoot\nmode:error\n" webresponse = self.encode(server.ProtocolError(args, "snoot")) self.assertEqual(webresponse.code, server.HTTP_ERROR) @@ -758,15 +792,17 @@ class TestSigningEncode(unittest.TestCase): trust_root='http://burr.unittest/', return_to='http://burr.unittest/999', immediate=False, - op_endpoint=self.server.op_endpoint, - ) + op_endpoint=self.server.op_endpoint, ) self.request.message = Message(OPENID2_NS) self.response = server.OpenIDResponse(self.request) self.response.fields = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'identity': self.request.identity, - 'return_to': self.request.return_to, - }) + 'mode': + 'id_res', + 'identity': + self.request.identity, + 'return_to': + self.request.return_to, + }) self.signatory = server.Signatory(self.store) self.encoder = server.SigningEncoder(self.signatory) self.encode = self.encoder.encode @@ -775,8 +811,8 @@ class TestSigningEncode(unittest.TestCase): assoc_handle = '{bicycle}{shed}' self.store.storeAssociation( self._normal_key, - association.Association.fromExpiresIn(60, assoc_handle, - 'sekrit', 'HMAC-SHA1')) + association.Association.fromExpiresIn(60, assoc_handle, 'sekrit', + 'HMAC-SHA1')) self.request.assoc_handle = assoc_handle webresponse = self.encode(self.response) self.assertEqual(webresponse.code, server.HTTP_REDIRECT) @@ -809,8 +845,7 @@ class TestSigningEncode(unittest.TestCase): trust_root='http://burr.unittest/', return_to='http://burr.unittest/999', immediate=False, - op_endpoint=self.server.op_endpoint, - ) + op_endpoint=self.server.op_endpoint, ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields.setArg(OPENID_NS, 'mode', 'cancel') @@ -849,8 +884,7 @@ class TestCheckID(unittest.TestCase): trust_root='http://bar.unittest/', return_to='http://bar.unittest/999', immediate=False, - op_endpoint=self.server.op_endpoint, - ) + op_endpoint=self.server.op_endpoint, ) self.request.message = Message(OPENID2_NS) def test_trustRootInvalid(self): @@ -873,8 +907,8 @@ class TestCheckID(unittest.TestCase): except server.MalformedTrustRoot as why: self.assertTrue(sentinel is why.openid_message) else: - self.fail('Expected MalformedTrustRoot exception. Got %r' - % (result,)) + self.fail('Expected MalformedTrustRoot exception. Got %r' % + (result, )) def test_trustRootValidNoReturnTo(self): request = server.CheckIDRequest( @@ -882,8 +916,7 @@ class TestCheckID(unittest.TestCase): trust_root='http://bar.unittest/', return_to=None, immediate=False, - op_endpoint=self.server.op_endpoint, - ) + op_endpoint=self.server.op_endpoint, ) self.assertTrue(request.trustRootValid()) @@ -891,6 +924,7 @@ class TestCheckID(unittest.TestCase): """Make sure that verifyReturnTo is calling the trustroot function verifyReturnTo """ + def withVerifyReturnTo(new_verify, callable): old_verify = server.verifyReturnTo try: @@ -918,20 +952,21 @@ class TestCheckID(unittest.TestCase): self.assertEqual(self.request.trust_root, trust_root) self.assertEqual(self.request.return_to, return_to) return val + return verify for val in [True, False]: - self.assertEqual( - val, - withVerifyReturnTo(constVerify(val), - self.request.returnToVerified)) + self.assertEqual(val, + withVerifyReturnTo( + constVerify(val), + self.request.returnToVerified)) def _expectAnswer(self, answer, identity=None, claimed_id=None): expected_list = [ ('mode', 'id_res'), ('return_to', self.request.return_to), ('op_endpoint', self.op_endpoint), - ] + ] if identity: expected_list.append(('identity', identity)) if claimed_id: @@ -942,16 +977,15 @@ class TestCheckID(unittest.TestCase): for k, expected in expected_list: actual = answer.fields.getArg(OPENID_NS, k) self.assertEqual(actual, expected, - "%s: expected %s, got %s" % - (k, expected, actual)) + "%s: expected %s, got %s" % (k, expected, actual)) self.assertTrue(answer.fields.hasKey(OPENID_NS, 'response_nonce')) self.assertTrue(answer.fields.getOpenIDNamespace() == OPENID2_NS) # One for nonce, one for ns - self.assertEqual(len(answer.fields.toPostArgs()), - len(expected_list) + 2, - answer.fields.toPostArgs()) + self.assertEqual( + len(answer.fields.toPostArgs()), + len(expected_list) + 2, answer.fields.toPostArgs()) def test_answerAllow(self): """Check the fields specified by "Positive Assertions" @@ -985,8 +1019,7 @@ class TestCheckID(unittest.TestCase): def test_answerAllowAnonymousFail(self): self.request.identity = None # XXX - Check on this, I think this behavior is legal in OpenID 2.0? - self.assertRaises( - ValueError, self.request.answer, True, identity="=V") + self.assertRaises(ValueError, self.request.answer, True, identity="=V") def test_answerAllowWithIdentity(self): self.request.identity = IDENTIFIER_SELECT @@ -1001,8 +1034,8 @@ class TestCheckID(unittest.TestCase): self.request.identity = IDENTIFIER_SELECT selected_id = 'http://anon.unittest/9861' claimed_id = 'http://monkeyhat.unittest/' - answer = self.request.answer(True, identity=selected_id, - claimed_id=claimed_id) + answer = self.request.answer( + True, identity=selected_id, claimed_id=claimed_id) self._expectAnswer(answer, selected_id, claimed_id) def test_answerAllowWithDelegatedIdentityOpenID1(self): @@ -1013,15 +1046,20 @@ class TestCheckID(unittest.TestCase): self.request.identity = IDENTIFIER_SELECT selected_id = 'http://anon.unittest/9861' claimed_id = 'http://monkeyhat.unittest/' - self.assertRaises(server.VersionError, - self.request.answer, True, - identity=selected_id, - claimed_id=claimed_id) + self.assertRaises( + server.VersionError, + self.request.answer, + True, + identity=selected_id, + claimed_id=claimed_id) def test_answerAllowWithAnotherIdentity(self): # XXX - Check on this, I think this behavior is legal in OpenID 2.0? - self.assertRaises(ValueError, self.request.answer, True, - identity="http://pebbles.unittest/") + self.assertRaises( + ValueError, + self.request.answer, + True, + identity="http://pebbles.unittest/") def test_answerAllowWithIdentityNormalization(self): # The RP has sent us a non-normalized value for openid.identity, @@ -1037,14 +1075,13 @@ class TestCheckID(unittest.TestCase): # Expect the values that were sent in the request, even though # they're not normalized. - self._expectAnswer(answer, identity=non_normalized, - claimed_id=non_normalized) + self._expectAnswer( + answer, identity=non_normalized, claimed_id=non_normalized) def test_answerAllowNoIdentityOpenID1(self): self.request.message = Message(OPENID1_NS) self.request.identity = None - self.assertRaises(ValueError, self.request.answer, True, - identity=None) + self.assertRaises(ValueError, self.request.answer, True, identity=None) def test_answerAllowForgotEndpoint(self): self.request.op_endpoint = None @@ -1058,8 +1095,7 @@ class TestCheckID(unittest.TestCase): msg.setArg(OPENID_NS, 'assoc_handle', 'bogus') self.assertRaises(server.ProtocolError, - server.CheckIDRequest.fromMessage, - msg, self.server) + server.CheckIDRequest.fromMessage, msg, self.server) def test_fromMessageClaimedIDWithoutIdentityOpenID2(self): name = 'https://example.myopenid.com' @@ -1070,8 +1106,7 @@ class TestCheckID(unittest.TestCase): msg.setArg(OPENID_NS, 'claimed_id', name) self.assertRaises(server.ProtocolError, - server.CheckIDRequest.fromMessage, - msg, self.server) + server.CheckIDRequest.fromMessage, msg, self.server) def test_fromMessageIdentityWithoutClaimedIDOpenID2(self): name = 'https://example.myopenid.com' @@ -1082,8 +1117,7 @@ class TestCheckID(unittest.TestCase): msg.setArg(OPENID_NS, 'identity', name) self.assertRaises(server.ProtocolError, - server.CheckIDRequest.fromMessage, - msg, self.server) + server.CheckIDRequest.fromMessage, msg, self.server) def test_trustRootOpenID1(self): """Ignore openid.realm in OpenID 1""" @@ -1095,8 +1129,8 @@ class TestCheckID(unittest.TestCase): msg.setArg(OPENID_NS, 'assoc_handle', 'bogus') msg.setArg(OPENID_NS, 'identity', 'george') - result = server.CheckIDRequest.fromMessage( - msg, self.server.op_endpoint) + result = server.CheckIDRequest.fromMessage(msg, + self.server.op_endpoint) self.assertTrue(result.trust_root == 'http://real_trust_root/') @@ -1111,8 +1145,8 @@ class TestCheckID(unittest.TestCase): msg.setArg(OPENID_NS, 'identity', 'george') msg.setArg(OPENID_NS, 'claimed_id', 'george') - result = server.CheckIDRequest.fromMessage( - msg, self.server.op_endpoint) + result = server.CheckIDRequest.fromMessage(msg, + self.server.op_endpoint) self.assertTrue(result.trust_root == 'http://real_trust_root/') @@ -1138,13 +1172,21 @@ class TestCheckID(unittest.TestCase): def test_fromMessageWithEmptyTrustRoot(self): return_to = 'http://someplace.invalid/?go=thing' msg = Message.fromPostArgs({ - 'openid.assoc_handle': '{blah}{blah}{OZivdQ==}', - 'openid.claimed_id': 'http://delegated.invalid/', - 'openid.identity': 'http://op-local.example.com/', - 'openid.mode': 'checkid_setup', - 'openid.ns': 'http://openid.net/signon/1.0', - 'openid.return_to': return_to, - 'openid.trust_root': ''}) + 'openid.assoc_handle': + '{blah}{blah}{OZivdQ==}', + 'openid.claimed_id': + 'http://delegated.invalid/', + 'openid.identity': + 'http://op-local.example.com/', + 'openid.mode': + 'checkid_setup', + 'openid.ns': + 'http://openid.net/signon/1.0', + 'openid.return_to': + return_to, + 'openid.trust_root': + '' + }) result = server.CheckIDRequest.fromMessage(msg, self.server.op_endpoint) @@ -1159,8 +1201,8 @@ class TestCheckID(unittest.TestCase): msg.setArg(OPENID_NS, 'claimed_id', 'george') self.assertRaises(server.ProtocolError, - server.CheckIDRequest.fromMessage, - msg, self.server.op_endpoint) + server.CheckIDRequest.fromMessage, msg, + self.server.op_endpoint) def test_answerAllowNoEndpointOpenID1(self): """Test .allow() with an OpenID 1.x Message on a CheckIDRequest @@ -1168,10 +1210,13 @@ class TestCheckID(unittest.TestCase): """ identity = 'http://bambam.unittest/' reqmessage = Message.fromOpenIDArgs({ - 'identity': identity, - 'trust_root': 'http://bar.unittest/', - 'return_to': 'http://bar.unittest/999', - }) + 'identity': + identity, + 'trust_root': + 'http://bar.unittest/', + 'return_to': + 'http://bar.unittest/999', + }) self.request = server.CheckIDRequest.fromMessage(reqmessage, None) answer = self.request.answer(True) @@ -1179,22 +1224,21 @@ class TestCheckID(unittest.TestCase): ('mode', 'id_res'), ('return_to', self.request.return_to), ('identity', identity), - ] + ] for k, expected in expected_list: actual = answer.fields.getArg(OPENID_NS, k) - self.assertEqual( - expected, actual, - "%s: expected %s, got %s" % (k, expected, actual)) + self.assertEqual(expected, actual, + "%s: expected %s, got %s" % (k, expected, actual)) self.assertTrue(answer.fields.hasKey(OPENID_NS, 'response_nonce')) self.assertEqual(answer.fields.getOpenIDNamespace(), OPENID1_NS) self.assertTrue(answer.fields.namespaces.isImplicit(OPENID1_NS)) # One for nonce (OpenID v1 namespace is implicit) - self.assertEqual(len(answer.fields.toPostArgs()), - len(expected_list) + 1, - answer.fields.toPostArgs()) + self.assertEqual( + len(answer.fields.toPostArgs()), + len(expected_list) + 1, answer.fields.toPostArgs()) def test_answerImmediateDenyOpenID2(self): """Look for mode=setup_needed in checkid_immediate negative @@ -1212,8 +1256,8 @@ class TestCheckID(unittest.TestCase): self.assertEqual(answer.request, self.request) self.assertEqual(len(answer.fields.toPostArgs()), 3, answer.fields) self.assertEqual(answer.fields.getOpenIDNamespace(), OPENID2_NS) - self.assertEqual(answer.fields.getArg(OPENID_NS, 'mode'), - 'setup_needed') + self.assertEqual( + answer.fields.getArg(OPENID_NS, 'mode'), 'setup_needed') usu = answer.fields.getArg(OPENID_NS, 'user_setup_url') expected_substr = 'openid.claimed_id=http%3A%2F%2Fclaimed-id.test%2F' @@ -1233,8 +1277,9 @@ class TestCheckID(unittest.TestCase): self.assertEqual(answer.fields.getOpenIDNamespace(), OPENID1_NS) self.assertTrue(answer.fields.namespaces.isImplicit(OPENID1_NS)) self.assertEqual(answer.fields.getArg(OPENID_NS, 'mode'), 'id_res') - self.assertTrue(answer.fields.getArg( - OPENID_NS, 'user_setup_url', '').startswith(server_url)) + self.assertTrue( + answer.fields.getArg(OPENID_NS, 'user_setup_url', '').startswith( + server_url)) def test_answerSetupDeny(self): answer = self.request.answer(False) @@ -1260,8 +1305,9 @@ class TestCheckID(unittest.TestCase): rt, query_string = url.split('?') self.assertEqual(self.request.return_to, rt) query = dict(parse_qsl(query_string)) - self.assertEqual(query, {'openid.mode': 'cancel', - 'openid.ns': OPENID2_NS}) + self.assertEqual(query, + {'openid.mode': 'cancel', + 'openid.ns': OPENID2_NS}) def test_getCancelURLimmed(self): self.request.mode = 'checkid_immediate' @@ -1279,8 +1325,7 @@ class TestCheckIDExtension(unittest.TestCase): trust_root='http://bar.unittest/', return_to='http://bar.unittest/999', immediate=False, - op_endpoint=self.server.op_endpoint, - ) + op_endpoint=self.server.op_endpoint, ) self.request.message = Message(OPENID2_NS) self.response = server.OpenIDResponse(self.request) self.response.fields.setArg(OPENID_NS, 'mode', 'id_res') @@ -1289,21 +1334,23 @@ class TestCheckIDExtension(unittest.TestCase): def test_addField(self): namespace = 'something:' self.response.fields.setArg(namespace, 'bright', 'potato') - self.assertEqual(self.response.fields.getArgs(OPENID_NS), { + self.assertEqual( + self.response.fields.getArgs(OPENID_NS), { 'blue': 'star', 'mode': 'id_res', - }) + }) - self.assertEqual(self.response.fields.getArgs(namespace), - {'bright': 'potato'}) + self.assertEqual( + self.response.fields.getArgs(namespace), {'bright': 'potato'}) def test_addFields(self): namespace = 'mi5:' - args = {'tangy': 'suspenders', - 'bravo': 'inclusion'} + args = {'tangy': 'suspenders', 'bravo': 'inclusion'} self.response.fields.updateArgs(namespace, args) - self.assertEqual(self.response.fields.getArgs(OPENID_NS), - {'blue': 'star', 'mode': 'id_res'}) + self.assertEqual( + self.response.fields.getArgs(OPENID_NS), + {'blue': 'star', + 'mode': 'id_res'}) self.assertEqual(self.response.fields.getArgs(namespace), args) @@ -1340,23 +1387,20 @@ class TestCheckAuth(unittest.TestCase): 'openid.sig': 'signarture', 'one': 'alpha', 'two': 'beta', - }) - self.request = server.CheckAuthRequest( - self.assoc_handle, self.message) + }) + self.request = server.CheckAuthRequest(self.assoc_handle, self.message) self.signatory = MockSignatory((True, self.assoc_handle)) def test_valid(self): r = self.request.answer(self.signatory) - self.assertEqual(r.fields.getArgs(OPENID_NS), - {'is_valid': 'true'}) + self.assertEqual(r.fields.getArgs(OPENID_NS), {'is_valid': 'true'}) self.assertEqual(r.request, self.request) def test_invalid(self): self.signatory.isValid = False r = self.request.answer(self.signatory) - self.assertEqual(r.fields.getArgs(OPENID_NS), - {'is_valid': 'false'}) + self.assertEqual(r.fields.getArgs(OPENID_NS), {'is_valid': 'false'}) def test_replay(self): """Don't validate the same response twice. @@ -1373,15 +1417,15 @@ class TestCheckAuth(unittest.TestCase): """ r = self.request.answer(self.signatory) r = self.request.answer(self.signatory) - self.assertEqual(r.fields.getArgs(OPENID_NS), - {'is_valid': 'false'}) + self.assertEqual(r.fields.getArgs(OPENID_NS), {'is_valid': 'false'}) def test_invalidatehandle(self): self.request.invalidate_handle = "bogusHandle" r = self.request.answer(self.signatory) - self.assertEqual(r.fields.getArgs(OPENID_NS), - {'is_valid': 'true', - 'invalidate_handle': "bogusHandle"}) + self.assertEqual( + r.fields.getArgs(OPENID_NS), + {'is_valid': 'true', + 'invalidate_handle': "bogusHandle"}) self.assertEqual(r.request, self.request) def test_invalidatehandleNo(self): @@ -1389,8 +1433,7 @@ class TestCheckAuth(unittest.TestCase): self.signatory.assocs.append((False, 'goodhandle')) self.request.invalidate_handle = assoc_handle r = self.request.answer(self.signatory) - self.assertEqual(r.fields.getArgs(OPENID_NS), - {'is_valid': 'true'}) + self.assertEqual(r.fields.getArgs(OPENID_NS), {'is_valid': 'true'}) class TestAssociate(unittest.TestCase): @@ -1430,6 +1473,7 @@ class TestAssociate(unittest.TestCase): if not cryptutil.SHA256_AVAILABLE: warnings.warn("Not running SHA256 tests.") else: + def test_dhSHA256(self): self.assoc = self.signatory.createAssociation( dumb=False, assoc_type='HMAC-SHA256') @@ -1460,51 +1504,57 @@ class TestAssociate(unittest.TestCase): s256_session = DiffieHellmanSHA256ConsumerSession() - invalid_s256 = {'openid.assoc_type': 'HMAC-SHA1', - 'openid.session_type': 'DH-SHA256'} + invalid_s256 = { + 'openid.assoc_type': 'HMAC-SHA1', + 'openid.session_type': 'DH-SHA256' + } invalid_s256.update(s256_session.getRequest()) - invalid_s256_2 = {'openid.assoc_type': 'MONKEY-PIRATE', - 'openid.session_type': 'DH-SHA256'} + invalid_s256_2 = { + 'openid.assoc_type': 'MONKEY-PIRATE', + 'openid.session_type': 'DH-SHA256' + } invalid_s256_2.update(s256_session.getRequest()) bad_request_argss = [ invalid_s256, invalid_s256_2, - ] + ] for request_args in bad_request_argss: message = Message.fromPostArgs(request_args) self.assertRaises(server.ProtocolError, - server.AssociateRequest.fromMessage, - message) + server.AssociateRequest.fromMessage, message) def test_protoError(self): from openid.consumer.consumer import DiffieHellmanSHA1ConsumerSession s1_session = DiffieHellmanSHA1ConsumerSession() - invalid_s1 = {'openid.assoc_type': 'HMAC-SHA256', - 'openid.session_type': 'DH-SHA1'} + invalid_s1 = { + 'openid.assoc_type': 'HMAC-SHA256', + 'openid.session_type': 'DH-SHA1' + } invalid_s1.update(s1_session.getRequest()) invalid_s1_2 = { 'openid.assoc_type': 'ROBOT-NINJA', 'openid.session_type': 'DH-SHA1' - } + } invalid_s1_2.update(s1_session.getRequest()) bad_request_argss = [ - {'openid.assoc_type':'Wha?'}, + { + 'openid.assoc_type': 'Wha?' + }, invalid_s1, invalid_s1_2, - ] + ] for request_args in bad_request_argss: message = Message.fromPostArgs(request_args) self.assertRaises(server.ProtocolError, - server.AssociateRequest.fromMessage, - message) + server.AssociateRequest.fromMessage, message) def test_protoErrorFields(self): @@ -1515,7 +1565,7 @@ class TestAssociate(unittest.TestCase): openid1_args = { 'openid.identitiy': 'invalid', 'openid.mode': 'checkid_setup', - } + } openid2_args = dict(openid1_args) openid2_args.update({'openid.ns': OPENID2_NS}) @@ -1523,16 +1573,16 @@ class TestAssociate(unittest.TestCase): # Check presence of optional fields in both protocol versions openid1_msg = Message.fromPostArgs(openid1_args) - p = server.ProtocolError(openid1_msg, error, - contact=contact, reference=reference) + p = server.ProtocolError( + openid1_msg, error, contact=contact, reference=reference) reply = p.toMessage() self.assertEqual(reply.getArg(OPENID_NS, 'reference'), reference) self.assertEqual(reply.getArg(OPENID_NS, 'contact'), contact) openid2_msg = Message.fromPostArgs(openid2_args) - p = server.ProtocolError(openid2_msg, error, - contact=contact, reference=reference) + p = server.ProtocolError( + openid2_msg, error, contact=contact, reference=reference) reply = p.toMessage() self.assertEqual(reply.getArg(OPENID_NS, 'reference'), reference) @@ -1561,8 +1611,8 @@ class TestAssociate(unittest.TestCase): self.assertEqual(rfg("assoc_type"), "HMAC-SHA1") self.assertEqual(rfg("assoc_handle"), self.assoc.handle) - self.failUnlessExpiresInMatches( - response.fields, self.signatory.SECRET_LIFETIME) + self.failUnlessExpiresInMatches(response.fields, + self.signatory.SECRET_LIFETIME) # remember, oidutil.toBase64 returns bytes... r_mac_key = rfg("mac_key").encode('utf-8') @@ -1579,7 +1629,7 @@ class TestAssociate(unittest.TestCase): 'openid.mode': 'associate', 'openid.assoc_type': 'HMAC-SHA1', 'openid.session_type': 'no-encryption', - } + } self.request = server.AssociateRequest.fromMessage( Message.fromPostArgs(args)) @@ -1593,8 +1643,8 @@ class TestAssociate(unittest.TestCase): self.assertEqual(rfg("assoc_type"), "HMAC-SHA1") self.assertEqual(rfg("assoc_handle"), self.assoc.handle) - self.failUnlessExpiresInMatches( - response.fields, self.signatory.SECRET_LIFETIME) + self.failUnlessExpiresInMatches(response.fields, + self.signatory.SECRET_LIFETIME) # rfg gets from the response which will return str; oidutil.toBase64 # returns bytes. Make them comparable by bytes-ifying the mac_key @@ -1614,8 +1664,8 @@ class TestAssociate(unittest.TestCase): self.assertEqual(rfg("assoc_type"), "HMAC-SHA1") self.assertEqual(rfg("assoc_handle"), self.assoc.handle) - self.failUnlessExpiresInMatches( - response.fields, self.signatory.SECRET_LIFETIME) + self.failUnlessExpiresInMatches(response.fields, + self.signatory.SECRET_LIFETIME) # remember, oidutil.toBase64 returns bytes... r_mac_key = rfg("mac_key").encode("utf-8") @@ -1636,8 +1686,7 @@ class TestAssociate(unittest.TestCase): response = self.request.answerUnsupported( message=message, preferred_session_type=allowed_sess, - preferred_association_type=allowed_assoc, - ) + preferred_association_type=allowed_assoc, ) rfg = lambda f: response.fields.getArg(OPENID_NS, f) self.assertEqual(rfg('error_code'), 'unsupported-type') self.assertEqual(rfg('assoc_type'), allowed_assoc) @@ -1680,6 +1729,7 @@ class TestServer(unittest.TestCase, CatchLogs): monkeycalled.inc() r = server.OpenIDResponse(request) return r + self.server.openid_monkeymode = monkeyDo request = server.OpenIDRequest() request.mode = "monkeymode" @@ -1690,8 +1740,9 @@ class TestServer(unittest.TestCase, CatchLogs): def test_associate(self): request = server.AssociateRequest.fromMessage(Message.fromPostArgs({})) response = self.server.openid_associate(request) - self.assertTrue(response.fields.hasKey(OPENID_NS, "assoc_handle"), - "No assoc_handle here: %s" % (response.fields,)) + self.assertTrue( + response.fields.hasKey(OPENID_NS, "assoc_handle"), + "No assoc_handle here: %s" % (response.fields, )) def test_associate2(self): """Associate when the server has no allowed association types @@ -1705,7 +1756,7 @@ class TestServer(unittest.TestCase, CatchLogs): msg = Message.fromPostArgs({ 'openid.ns': OPENID2_NS, 'openid.session_type': 'no-encryption', - }) + }) request = server.AssociateRequest.fromMessage(msg) @@ -1727,7 +1778,7 @@ class TestServer(unittest.TestCase, CatchLogs): msg = Message.fromPostArgs({ 'openid.ns': OPENID2_NS, 'openid.session_type': 'no-encryption', - }) + }) request = server.AssociateRequest.fromMessage(msg) response = self.server.openid_associate(request) @@ -1735,14 +1786,15 @@ class TestServer(unittest.TestCase, CatchLogs): self.assertTrue(response.fields.hasKey(OPENID_NS, "error")) self.assertTrue(response.fields.hasKey(OPENID_NS, "error_code")) self.assertFalse(response.fields.hasKey(OPENID_NS, "assoc_handle")) - self.assertEqual(response.fields.getArg(OPENID_NS, "assoc_type"), - 'HMAC-SHA256') - self.assertEqual(response.fields.getArg(OPENID_NS, "session_type"), - 'DH-SHA256') + self.assertEqual( + response.fields.getArg(OPENID_NS, "assoc_type"), 'HMAC-SHA256') + self.assertEqual( + response.fields.getArg(OPENID_NS, "session_type"), 'DH-SHA256') if not cryptutil.SHA256_AVAILABLE: warnings.warn("Not running SHA256 tests.") else: + def test_associate4(self): """DH-SHA256 association session""" self.server.negotiator.setAllowedTypes( @@ -1752,9 +1804,11 @@ class TestServer(unittest.TestCase, CatchLogs): 'ALZgnx8N5Lgd7pCj8K86T/DDMFjJXSss1SKoLmxE72kJTzOtG6I2PaYrHX' 'xku4jMQWSsGfLJxwCZ6280uYjUST/9NWmuAfcrBfmDHIBc3H8xh6RBnlXJ' '1WxJY3jHd5k1/ZReyRZOxZTKdF/dnIqwF8ZXUwI6peV0TyS/K1fOfF/s', - 'openid.assoc_type': 'HMAC-SHA256', - 'openid.session_type': 'DH-SHA256', - } + 'openid.assoc_type': + 'HMAC-SHA256', + 'openid.session_type': + 'DH-SHA256', + } message = Message.fromPostArgs(query) request = server.AssociateRequest.fromMessage(message) response = self.server.openid_associate(request) @@ -1764,7 +1818,7 @@ class TestServer(unittest.TestCase, CatchLogs): """Make sure session_type is required in OpenID 2""" msg = Message.fromPostArgs({ 'openid.ns': OPENID2_NS, - }) + }) self.assertRaises(server.ProtocolError, server.AssociateRequest.fromMessage, msg) @@ -1788,8 +1842,8 @@ class TestSignatory(unittest.TestCase, CatchLogs): assoc_handle = '{assoc}{lookatme}' self.store.storeAssociation( self._normal_key, - association.Association.fromExpiresIn(60, assoc_handle, - 'sekrit', 'HMAC-SHA1')) + association.Association.fromExpiresIn(60, assoc_handle, 'sekrit', + 'HMAC-SHA1')) request.assoc_handle = assoc_handle request.namespace = OPENID1_NS response = server.OpenIDResponse(request) @@ -1797,13 +1851,13 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'foo': 'amsigned', 'bar': 'notsigned', 'azu': 'alsosigned', - }) + }) sresponse = self.signatory.sign(response) self.assertEqual( - sresponse.fields.getArg(OPENID_NS, 'assoc_handle'), - assoc_handle) - self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), - 'assoc_handle,azu,bar,foo,signed') + sresponse.fields.getArg(OPENID_NS, 'assoc_handle'), assoc_handle) + self.assertEqual( + sresponse.fields.getArg(OPENID_NS, 'signed'), + 'assoc_handle,azu,bar,foo,signed') self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) self.assertFalse(self.messages, self.messages) @@ -1817,14 +1871,15 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'bar': 'notsigned', 'azu': 'alsosigned', 'ns': OPENID2_NS, - }) + }) sresponse = self.signatory.sign(response) assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') self.assertTrue(assoc_handle) assoc = self.signatory.getAssociation(assoc_handle, dumb=True) self.assertTrue(assoc) - self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), - 'assoc_handle,azu,bar,foo,ns,signed') + self.assertEqual( + sresponse.fields.getArg(OPENID_NS, 'signed'), + 'assoc_handle,azu,bar,foo,ns,signed') self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) self.assertFalse(self.messages, self.messages) @@ -1849,10 +1904,10 @@ class TestSignatory(unittest.TestCase, CatchLogs): assoc_handle = '{assoc}{lookatme}' self.store.storeAssociation( self._normal_key, - association.Association.fromExpiresIn(-10, assoc_handle, - 'sekrit', 'HMAC-SHA1')) - self.assertTrue(self.store.getAssociation( - self._normal_key, assoc_handle)) + association.Association.fromExpiresIn(-10, assoc_handle, 'sekrit', + 'HMAC-SHA1')) + self.assertTrue( + self.store.getAssociation(self._normal_key, assoc_handle)) request.assoc_handle = assoc_handle response = server.OpenIDResponse(request) @@ -1860,7 +1915,7 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'foo': 'amsigned', 'bar': 'notsigned', 'azu': 'alsosigned', - }) + }) sresponse = self.signatory.sign(response) new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') @@ -1877,8 +1932,9 @@ class TestSignatory(unittest.TestCase, CatchLogs): self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) # make sure the expired association is gone - self.assertFalse(self.store.getAssociation(self._normal_key, assoc_handle), - "expired association is still retrievable.") + self.assertFalse( + self.store.getAssociation(self._normal_key, assoc_handle), + "expired association is still retrievable.") # make sure the new key is a dumb mode association self.assertTrue( @@ -1898,7 +1954,7 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'foo': 'amsigned', 'bar': 'notsigned', 'azu': 'alsosigned', - }) + }) sresponse = self.signatory.sign(response) new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') @@ -1923,18 +1979,23 @@ class TestSignatory(unittest.TestCase, CatchLogs): def test_verify(self): assoc_handle = '{vroom}{zoom}' - assoc = association.Association.fromExpiresIn( - 60, assoc_handle, 'sekrit', 'HMAC-SHA1') + assoc = association.Association.fromExpiresIn(60, assoc_handle, + 'sekrit', 'HMAC-SHA1') self.store.storeAssociation(self._dumb_key, assoc) signed = Message.fromPostArgs({ - 'openid.foo': 'bar', - 'openid.apple': 'orange', - 'openid.assoc_handle': assoc_handle, - 'openid.signed': 'apple,assoc_handle,foo,signed', - 'openid.sig': 'uXoT1qm62/BB09Xbj98TQ8mlBco=', - }) + 'openid.foo': + 'bar', + 'openid.apple': + 'orange', + 'openid.assoc_handle': + assoc_handle, + 'openid.signed': + 'apple,assoc_handle,foo,signed', + 'openid.sig': + 'uXoT1qm62/BB09Xbj98TQ8mlBco=', + }) verified = self.signatory.verify(assoc_handle, signed) self.assertFalse(self.messages, self.messages) @@ -1942,18 +2003,23 @@ class TestSignatory(unittest.TestCase, CatchLogs): def test_verifyBadSig(self): assoc_handle = '{vroom}{zoom}' - assoc = association.Association.fromExpiresIn( - 60, assoc_handle, 'sekrit', 'HMAC-SHA1') + assoc = association.Association.fromExpiresIn(60, assoc_handle, + 'sekrit', 'HMAC-SHA1') self.store.storeAssociation(self._dumb_key, assoc) signed = Message.fromPostArgs({ - 'openid.foo': 'bar', - 'openid.apple': 'orange', - 'openid.assoc_handle': assoc_handle, - 'openid.signed': 'apple,assoc_handle,foo,signed', - 'openid.sig': 'uXoT1qm62/BB09Xbj98TQ8mlBco='[::-1], - }) + 'openid.foo': + 'bar', + 'openid.apple': + 'orange', + 'openid.assoc_handle': + assoc_handle, + 'openid.signed': + 'apple,assoc_handle,foo,signed', + 'openid.sig': + 'uXoT1qm62/BB09Xbj98TQ8mlBco=' [::-1], + }) verified = self.signatory.verify(assoc_handle, signed) self.assertFalse(self.messages, self.messages) @@ -1962,10 +2028,13 @@ class TestSignatory(unittest.TestCase, CatchLogs): def test_verifyBadHandle(self): assoc_handle = '{vroom}{zoom}' signed = Message.fromPostArgs({ - 'foo': 'bar', - 'apple': 'orange', - 'openid.sig': "Ylu0KcIR7PvNegB/K41KpnRgJl0=", - }) + 'foo': + 'bar', + 'apple': + 'orange', + 'openid.sig': + "Ylu0KcIR7PvNegB/K41KpnRgJl0=", + }) verified = self.signatory.verify(assoc_handle, signed) self.assertFalse(verified) @@ -1974,16 +2043,19 @@ class TestSignatory(unittest.TestCase, CatchLogs): def test_verifyAssocMismatch(self): """Attempt to validate sign-all message with a signed-list assoc.""" assoc_handle = '{vroom}{zoom}' - assoc = association.Association.fromExpiresIn( - 60, assoc_handle, 'sekrit', 'HMAC-SHA1') + assoc = association.Association.fromExpiresIn(60, assoc_handle, + 'sekrit', 'HMAC-SHA1') self.store.storeAssociation(self._dumb_key, assoc) signed = Message.fromPostArgs({ - 'foo': 'bar', - 'apple': 'orange', - 'openid.sig': "d71xlHtqnq98DonoSgoK/nD+QRM=", - }) + 'foo': + 'bar', + 'apple': + 'orange', + 'openid.sig': + "d71xlHtqnq98DonoSgoK/nD+QRM=", + }) verified = self.signatory.verify(assoc_handle, signed) self.assertFalse(verified) @@ -2004,8 +2076,7 @@ class TestSignatory(unittest.TestCase, CatchLogs): def test_getAssocInvalid(self): ah = 'no-such-handle' - self.assertEqual( - self.signatory.getAssociation(ah, dumb=False), None) + self.assertEqual(self.signatory.getAssociation(ah, dumb=False), None) self.assertFalse(self.messages, self.messages) def test_getAssocDumbVsNormal(self): @@ -2039,8 +2110,8 @@ class TestSignatory(unittest.TestCase, CatchLogs): assoc = association.Association.fromExpiresIn(lifetime, assoc_handle, 'sekrit', 'HMAC-SHA1') - self.store.storeAssociation( - (dumb and self._dumb_key) or self._normal_key, assoc) + self.store.storeAssociation((dumb and self._dumb_key) or + self._normal_key, assoc) return assoc_handle def test_invalidate(self): diff --git a/openid/test/test_services.py b/openid/test/test_services.py index 77afce68c8367433732472bc4600b0e26bbd7109..980a6d048d1bb55150f59558346345e6e6f3b64e 100644 --- a/openid/test/test_services.py +++ b/openid/test/test_services.py @@ -18,6 +18,5 @@ class TestGetServiceEndpoints(unittest.TestCase): return result def test_catchXRDSError(self): - self.assertRaises(DiscoveryFailure, - services.getServiceEndpoints, - "http://example.invalid/sometest") + self.assertRaises(DiscoveryFailure, services.getServiceEndpoints, + "http://example.invalid/sometest") diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py index 1b360510caae0bdc356e567b609537c3d42e29c1..2d0cecc64b6632781501da4b48ad5a787ada9e20 100644 --- a/openid/test/test_sreg.py +++ b/openid/test/test_sreg.py @@ -38,7 +38,7 @@ class SupportsSRegTest(unittest.TestCase): endpoint = FakeEndpoint([]) self.assertFalse(sreg.supportsSReg(endpoint)) self.assertEqual([sreg.ns_uri_1_1, sreg.ns_uri_1_0], - endpoint.checked_uris) + endpoint.checked_uris) def test_supported_1_1(self): endpoint = FakeEndpoint([sreg.ns_uri_1_1]) @@ -49,7 +49,7 @@ class SupportsSRegTest(unittest.TestCase): endpoint = FakeEndpoint([sreg.ns_uri_1_0]) self.assertTrue(sreg.supportsSReg(endpoint)) self.assertEqual([sreg.ns_uri_1_1, sreg.ns_uri_1_0], - endpoint.checked_uris) + endpoint.checked_uris) class FakeMessage(object): @@ -91,21 +91,19 @@ class GetNSTest(unittest.TestCase): self.msg.openid1 = openid_version self.msg.namespaces.addAlias(sreg_version, alias) ns_uri = sreg.getSRegNS(self.msg) - self.assertEqual(self.msg.namespaces.getAlias(ns_uri), - alias) + self.assertEqual( + self.msg.namespaces.getAlias(ns_uri), alias) self.assertEqual(sreg_version, ns_uri) def test_openID1DefinedBadly(self): self.msg.openid1 = True self.msg.namespaces.addAlias('http://invalid/', 'sreg') - self.assertRaises(sreg.SRegNamespaceError, - sreg.getSRegNS, self.msg) + self.assertRaises(sreg.SRegNamespaceError, sreg.getSRegNS, self.msg) def test_openID2DefinedBadly(self): self.msg.openid1 = False self.msg.namespaces.addAlias('http://invalid/', 'sreg') - self.assertRaises(sreg.SRegNamespaceError, - sreg.getSRegNS, self.msg) + self.assertRaises(sreg.SRegNamespaceError, sreg.getSRegNS, self.msg) def test_openID2Defined_1_0(self): self.msg.namespaces.add(sreg.ns_uri_1_0) @@ -116,7 +114,7 @@ class GetNSTest(unittest.TestCase): args = { 'sreg.optional': 'nickname', 'sreg.required': 'dob', - } + } m = Message.fromOpenIDArgs(args) @@ -133,20 +131,15 @@ class SRegRequestTest(unittest.TestCase): self.assertEqual(sreg.ns_uri, req.ns_uri) def test_constructFields(self): - req = sreg.SRegRequest( - ['nickname'], - ['gender'], - 'http://policy', - 'http://sreg.ns_uri') + req = sreg.SRegRequest(['nickname'], ['gender'], 'http://policy', + 'http://sreg.ns_uri') self.assertEqual(['gender'], req.optional) self.assertEqual(['nickname'], req.required) self.assertEqual('http://policy', req.policy_url) self.assertEqual('http://sreg.ns_uri', req.ns_uri) def test_constructBadFields(self): - self.assertRaises( - ValueError, - sreg.SRegRequest, ['elvis']) + self.assertRaises(ValueError, sreg.SRegRequest, ['elvis']) def test_fromOpenIDRequest(self): args = {} @@ -201,7 +194,8 @@ class SRegRequestTest(unittest.TestCase): req = sreg.SRegRequest() self.assertRaises( ValueError, - req.parseExtensionArgs, {'required': 'beans'}, strict=True) + req.parseExtensionArgs, {'required': 'beans'}, + strict=True) def test_parseExtensionArgs_policy(self): req = sreg.SRegRequest() @@ -242,8 +236,10 @@ class SRegRequestTest(unittest.TestCase): def test_parseExtensionArgs_bothNonStrict(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional': 'nickname', - 'required': 'nickname'}) + req.parseExtensionArgs({ + 'optional': 'nickname', + 'required': 'nickname' + }) self.assertEqual([], req.optional) self.assertEqual(['nickname'], req.required) @@ -258,8 +254,12 @@ class SRegRequestTest(unittest.TestCase): def test_parseExtensionArgs_bothList(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional': 'nickname,email', - 'required': 'country,postcode'}, strict=True) + req.parseExtensionArgs( + { + 'optional': 'nickname,email', + 'required': 'country,postcode' + }, + strict=True) self.assertEqual(['nickname', 'email'], req.optional) self.assertEqual(['country', 'postcode'], req.required) @@ -295,13 +295,10 @@ class SRegRequestTest(unittest.TestCase): def test_requestField_bogus(self): req = sreg.SRegRequest() - self.assertRaises( - ValueError, - req.requestField, 'something else') + self.assertRaises(ValueError, req.requestField, 'something else') self.assertRaises( - ValueError, - req.requestField, 'something else', strict=True) + ValueError, req.requestField, 'something else', strict=True) def test_requestField(self): # Add all of the fields, one at a time @@ -388,24 +385,29 @@ class SRegRequestTest(unittest.TestCase): self.assertEqual({'optional': 'nickname'}, req.getExtensionArgs()) req.requestField('email') - self.assertEqual({'optional': 'nickname,email'}, - req.getExtensionArgs()) + self.assertEqual({ + 'optional': 'nickname,email' + }, req.getExtensionArgs()) req.requestField('gender', required=True) - self.assertEqual({'optional': 'nickname,email', - 'required': 'gender'}, - req.getExtensionArgs()) + self.assertEqual({ + 'optional': 'nickname,email', + 'required': 'gender' + }, req.getExtensionArgs()) req.requestField('postcode', required=True) - self.assertEqual({'optional': 'nickname,email', - 'required': 'gender,postcode'}, - req.getExtensionArgs()) + self.assertEqual({ + 'optional': 'nickname,email', + 'required': 'gender,postcode' + }, req.getExtensionArgs()) req.policy_url = 'http://policy.invalid/' - self.assertEqual({'optional': 'nickname,email', - 'required': 'gender,postcode', - 'policy_url': 'http://policy.invalid/'}, - req.getExtensionArgs()) + self.assertEqual({ + 'optional': 'nickname,email', + 'required': 'gender,postcode', + 'policy_url': 'http://policy.invalid/' + }, req.getExtensionArgs()) + data = { 'nickname': 'linusaur', @@ -416,7 +418,7 @@ data = { 'email': 'president@whitehouse.gov', 'dob': '0000-00-00', 'language': 'en-us', - } +} class DummySuccessResponse(object): @@ -442,7 +444,7 @@ class SRegResponseTest(unittest.TestCase): def test_fromSuccessResponse_signed(self): message = Message.fromOpenIDArgs({ 'sreg.nickname': 'The Mad Stork', - }) + }) success_resp = DummySuccessResponse(message, {}) sreg_resp = sreg.SRegResponse.fromSuccessResponse(success_resp) self.assertFalse(sreg_resp) @@ -450,19 +452,19 @@ class SRegResponseTest(unittest.TestCase): def test_fromSuccessResponse_unsigned(self): message = Message.fromOpenIDArgs({ 'sreg.nickname': 'The Mad Stork', - }) + }) success_resp = DummySuccessResponse(message, {}) - sreg_resp = sreg.SRegResponse.fromSuccessResponse(success_resp, - signed_only=False) + sreg_resp = sreg.SRegResponse.fromSuccessResponse( + success_resp, signed_only=False) self.assertEqual([('nickname', 'The Mad Stork')], - list(sreg_resp.items())) + list(sreg_resp.items())) class SendFieldsTest(unittest.TestCase): def test(self): # Create a request message with simple registration fields - sreg_req = sreg.SRegRequest(required=['nickname', 'email'], - optional=['fullname']) + sreg_req = sreg.SRegRequest( + required=['nickname', 'email'], optional=['fullname']) req_msg = Message() req_msg.updateArgs(sreg.ns_uri, sreg_req.getExtensionArgs()) @@ -486,10 +488,11 @@ class SendFieldsTest(unittest.TestCase): # Extract the fields that were sent sreg_data_resp = resp_msg.getArgs(sreg.ns_uri) self.assertEqual({ - 'nickname': 'linusaur', - 'email': 'president@whitehouse.gov', - 'fullname': 'Leonhard Euler', - }, sreg_data_resp) + 'nickname': 'linusaur', + 'email': 'president@whitehouse.gov', + 'fullname': 'Leonhard Euler', + }, sreg_data_resp) + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_symbol.py b/openid/test/test_symbol.py index a06e794b80006c0b59d306806d6dc8641fa3e7c9..8d8ee6c8d87faca0c19e5ad178f1ca7c120ba26d 100644 --- a/openid/test/test_symbol.py +++ b/openid/test/test_symbol.py @@ -2,6 +2,7 @@ import unittest from openid import oidutil + class SymbolTest(unittest.TestCase): def test_selfEquality(self): s = oidutil.Symbol('xxx') @@ -31,5 +32,6 @@ class SymbolTest(unittest.TestCase): y = oidutil.Symbol('yyy') self.assertTrue(x != y) + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_urinorm.py b/openid/test/test_urinorm.py index 406ad3d08f6fa45b92c15e0b6a3dac4f842b1160..e1258775701868cbbd86d521db5954371f06bcce 100644 --- a/openid/test/test_urinorm.py +++ b/openid/test/test_urinorm.py @@ -2,6 +2,7 @@ import os import unittest import openid.urinorm + class UrinormTest(unittest.TestCase): def __init__(self, desc, case, expected): unittest.TestCase.__init__(self) @@ -52,6 +53,7 @@ def pyUnitTests(): tests = parseTests(test_data) return unittest.TestSuite(tests) + if __name__ == '__main__': runner = unittest.TextTestRunner() runner.run(pyUnitTests()) diff --git a/openid/test/test_verifydisco.py b/openid/test/test_verifydisco.py index 069af7e6e9c8ced611917f9cfaf2555c9d70df97..3b62a6749ff2a83516f87526bd2950381657ad66 100644 --- a/openid/test/test_verifydisco.py +++ b/openid/test/test_verifydisco.py @@ -9,6 +9,7 @@ from openid.consumer import discover def const(result): """Return a function that ignores any arguments and just returns the specified result""" + def constResult(*args, **kwargs): return result @@ -33,45 +34,50 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): endpoint.claimed_id = 'bogus' msg = message.Message.fromOpenIDArgs({}) - self.failUnlessProtocolError( - 'Missing required field openid.identity', - self.consumer._verifyDiscoveryResults, msg, endpoint) + self.failUnlessProtocolError('Missing required field openid.identity', + self.consumer._verifyDiscoveryResults, + msg, endpoint) self.failUnlessLogEmpty() def test_openID1NoEndpoint(self): msg = message.Message.fromOpenIDArgs({'identity': 'snakes on a plane'}) - self.assertRaises(RuntimeError, - self.consumer._verifyDiscoveryResults, msg) + self.assertRaises(RuntimeError, self.consumer._verifyDiscoveryResults, + msg) self.failUnlessLogEmpty() def test_openID2NoOPEndpointArg(self): msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS}) - self.assertRaises(KeyError, - self.consumer._verifyDiscoveryResults, msg) + self.assertRaises(KeyError, self.consumer._verifyDiscoveryResults, msg) self.failUnlessLogEmpty() def test_openID2LocalIDNoClaimed(self): - msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, - 'op_endpoint': 'Phone Home', - 'identity': 'Jose Lius Borges'}) - self.failUnlessProtocolError( - 'openid.identity is present without', - self.consumer._verifyDiscoveryResults, msg) + msg = message.Message.fromOpenIDArgs({ + 'ns': message.OPENID2_NS, + 'op_endpoint': 'Phone Home', + 'identity': 'Jose Lius Borges' + }) + self.failUnlessProtocolError('openid.identity is present without', + self.consumer._verifyDiscoveryResults, + msg) self.failUnlessLogEmpty() def test_openID2NoLocalIDClaimed(self): - msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, - 'op_endpoint': 'Phone Home', - 'claimed_id': 'Manuel Noriega'}) - self.failUnlessProtocolError( - 'openid.claimed_id is present without', - self.consumer._verifyDiscoveryResults, msg) + msg = message.Message.fromOpenIDArgs({ + 'ns': message.OPENID2_NS, + 'op_endpoint': 'Phone Home', + 'claimed_id': 'Manuel Noriega' + }) + self.failUnlessProtocolError('openid.claimed_id is present without', + self.consumer._verifyDiscoveryResults, + msg) self.failUnlessLogEmpty() def test_openID2NoIdentifiers(self): op_endpoint = 'Phone Home' - msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, - 'op_endpoint': op_endpoint}) + msg = message.Message.fromOpenIDArgs({ + 'ns': message.OPENID2_NS, + 'op_endpoint': op_endpoint + }) result_endpoint = self.consumer._verifyDiscoveryResults(msg) self.assertTrue(result_endpoint.isOPIdentifier()) self.assertEqual(op_endpoint, result_endpoint.server_url) @@ -83,11 +89,12 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): sentinel = discover.OpenIDServiceEndpoint() sentinel.claimed_id = 'monkeysoft' self.consumer._discoverAndVerify = const(sentinel) - msg = message.Message.fromOpenIDArgs( - {'ns': message.OPENID2_NS, - 'identity': 'sour grapes', - 'claimed_id': 'monkeysoft', - 'op_endpoint': op_endpoint}) + msg = message.Message.fromOpenIDArgs({ + 'ns': message.OPENID2_NS, + 'identity': 'sour grapes', + 'claimed_id': 'monkeysoft', + 'op_endpoint': op_endpoint + }) result = self.consumer._verifyDiscoveryResults(msg) self.assertEqual(sentinel, result) self.failUnlessLogMatches('No pre-discovered') @@ -101,11 +108,12 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): sentinel = discover.OpenIDServiceEndpoint() sentinel.claimed_id = 'monkeysoft' self.consumer._discoverAndVerify = const(sentinel) - msg = message.Message.fromOpenIDArgs( - {'ns': message.OPENID2_NS, - 'identity': 'sour grapes', - 'claimed_id': 'monkeysoft', - 'op_endpoint': op_endpoint}) + msg = message.Message.fromOpenIDArgs({ + 'ns': message.OPENID2_NS, + 'identity': 'sour grapes', + 'claimed_id': 'monkeysoft', + 'op_endpoint': op_endpoint + }) result = self.consumer._verifyDiscoveryResults(msg, mismatched) self.assertEqual(sentinel, result) self.failUnlessLogMatches('Error attempting to use stored', @@ -118,11 +126,16 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): endpoint.server_url = 'Phone Home' endpoint.type_uris = [discover.OPENID_2_0_TYPE] - msg = message.Message.fromOpenIDArgs( - {'ns': message.OPENID2_NS, - 'identity': endpoint.local_id, - 'claimed_id': endpoint.claimed_id, - 'op_endpoint': endpoint.server_url}) + msg = message.Message.fromOpenIDArgs({ + 'ns': + message.OPENID2_NS, + 'identity': + endpoint.local_id, + 'claimed_id': + endpoint.claimed_id, + 'op_endpoint': + endpoint.server_url + }) result = self.consumer._verifyDiscoveryResults(msg, endpoint) self.assertTrue(result is endpoint) self.failUnlessLogEmpty() @@ -144,11 +157,16 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): self.consumer._discoverAndVerify = discoverAndVerify - msg = message.Message.fromOpenIDArgs( - {'ns': message.OPENID2_NS, - 'identity': endpoint.local_id, - 'claimed_id': endpoint.claimed_id, - 'op_endpoint': endpoint.server_url}) + msg = message.Message.fromOpenIDArgs({ + 'ns': + message.OPENID2_NS, + 'identity': + endpoint.local_id, + 'claimed_id': + endpoint.claimed_id, + 'op_endpoint': + endpoint.server_url + }) try: r = self.consumer._verifyDiscoveryResults(msg, endpoint) @@ -156,7 +174,7 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): # Should we make more ProtocolError subclasses? self.assertTrue(str(e), text) else: - self.fail("expected ProtocolError, %r returned." % (r,)) + self.fail("expected ProtocolError, %r returned." % (r, )) self.failUnlessLogMatches('Error attempting to use stored', 'Attempting discovery') @@ -168,9 +186,10 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): endpoint.server_url = 'Phone Home' endpoint.type_uris = [discover.OPENID_1_1_TYPE] - msg = message.Message.fromOpenIDArgs( - {'ns': message.OPENID1_NS, - 'identity': endpoint.local_id}) + msg = message.Message.fromOpenIDArgs({ + 'ns': message.OPENID1_NS, + 'identity': endpoint.local_id + }) result = self.consumer._verifyDiscoveryResults(msg, endpoint) self.assertTrue(result is endpoint) self.failUnlessLogEmpty() @@ -179,7 +198,6 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): class VerifiedError(Exception): pass - def discoverAndVerify(claimed_id, _to_match): raise VerifiedError @@ -191,13 +209,13 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): endpoint.server_url = 'Phone Home' endpoint.type_uris = [discover.OPENID_2_0_TYPE] - msg = message.Message.fromOpenIDArgs( - {'ns': message.OPENID1_NS, - 'identity': endpoint.local_id}) + msg = message.Message.fromOpenIDArgs({ + 'ns': message.OPENID1_NS, + 'identity': endpoint.local_id + }) - self.assertRaises( - VerifiedError, - self.consumer._verifyDiscoveryResults, msg, endpoint) + self.assertRaises(VerifiedError, self.consumer._verifyDiscoveryResults, + msg, endpoint) self.failUnlessLogMatches('Error attempting to use stored', 'Attempting discovery') @@ -211,11 +229,12 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): endpoint.server_url = 'Phone Home' endpoint.type_uris = [discover.OPENID_2_0_TYPE] - msg = message.Message.fromOpenIDArgs( - {'ns': message.OPENID2_NS, - 'identity': endpoint.local_id, - 'claimed_id': claimed_id_frag, - 'op_endpoint': endpoint.server_url}) + msg = message.Message.fromOpenIDArgs({ + 'ns': message.OPENID2_NS, + 'identity': endpoint.local_id, + 'claimed_id': claimed_id_frag, + 'op_endpoint': endpoint.server_url + }) result = self.consumer._verifyDiscoveryResults(msg, endpoint) self.assertEqual(result.local_id, endpoint.local_id) @@ -231,7 +250,8 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): endpoint = None resp_mesg = message.Message.fromOpenIDArgs({ 'ns': message.OPENID1_NS, - 'identity': claimed_id}) + 'identity': claimed_id + }) # Pass the OpenID 1 claimed_id this way since we're passing # None for the endpoint. resp_mesg.setArg(message.BARE_NS, 'openid1_claimed_id', claimed_id) @@ -247,10 +267,11 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): discovered_services = [expected_endpoint] self.consumer._discover = lambda *args: ('unused', discovered_services) - actual_endpoint = self.consumer._verifyDiscoveryResults( - resp_mesg, endpoint) + actual_endpoint = self.consumer._verifyDiscoveryResults(resp_mesg, + endpoint) self.assertTrue(actual_endpoint is expected_endpoint) + # XXX: test the implementation of _discoverAndVerify @@ -271,5 +292,6 @@ class TestVerifyDiscoverySingle(TestIdRes): self.assertEqual(result, None) self.failUnlessLogEmpty() + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_xri.py b/openid/test/test_xri.py index bbb8ed5547896bf106c19dd3fbe19d34ff2992e3..817357366679ce14b9d1a5c9df9a924e3bdcc90f 100644 --- a/openid/test/test_xri.py +++ b/openid/test/test_xri.py @@ -14,8 +14,8 @@ class XriDiscoveryTestCase(TestCase): class XriEscapingTestCase(TestCase): def test_escaping_percents(self): - self.assertEqual(xri.escapeForIRI('@example/abc%2Fd/ef'), - '@example/abc%252Fd/ef') + self.assertEqual( + xri.escapeForIRI('@example/abc%2Fd/ef'), '@example/abc%252Fd/ef') def test_escaping_xref(self): # no escapes @@ -23,12 +23,12 @@ class XriEscapingTestCase(TestCase): self.assertEqual('@example/foo/(@bar)', esc('@example/foo/(@bar)')) # escape slashes self.assertEqual('@example/foo/(@bar%2Fbaz)', - esc('@example/foo/(@bar/baz)')) + esc('@example/foo/(@bar/baz)')) self.assertEqual('@example/foo/(@bar%2Fbaz)/(+a%2Fb)', - esc('@example/foo/(@bar/baz)/(+a/b)')) + esc('@example/foo/(@bar/baz)/(+a/b)')) # escape query ? and fragment # self.assertEqual('@example/foo/(@baz%3Fp=q%23r)?i=j#k', - esc('@example/foo/(@baz?p=q#r)?i=j#k')) + esc('@example/foo/(@baz?p=q#r)?i=j#k')) class XriTransformationTestCase(TestCase): @@ -76,6 +76,7 @@ class TestGetRootAuthority(TestCase): def test(self): actual_root = xri.rootAuthority(the_xri) self.assertEqual(actual_root, xri.XRI(expected_root)) + return test test_at = mkTest("@foo", "@") @@ -97,6 +98,7 @@ class TestGetRootAuthority(TestCase): ##("example.com*bar/(=baz)", "example.com*bar"), ##("baz.example.com!01/foo", "baz.example.com!01"), + if __name__ == '__main__': import unittest unittest.main() diff --git a/openid/test/test_xrires.py b/openid/test/test_xrires.py index 321c37330eadb8f995381583153389ac292991a0..752d0fe1cdbcd5c3eb31a326a0c05e891174be72 100644 --- a/openid/test/test_xrires.py +++ b/openid/test/test_xrires.py @@ -1,7 +1,7 @@ - from unittest import TestCase from openid.yadis import xrires + class ProxyQueryTestCase(TestCase): def setUp(self): self.proxy_url = 'http://xri.example.com/' @@ -9,7 +9,6 @@ class ProxyQueryTestCase(TestCase): self.servicetype = 'xri://+i-service*(+forwarding)*($v*1.0)' self.servicetype_enc = 'xri%3A%2F%2F%2Bi-service%2A%28%2Bforwarding%29%2A%28%24v%2A1.0%29' - def test_proxy_url(self): st = self.servicetype ste = self.servicetype_enc @@ -18,17 +17,16 @@ class ProxyQueryTestCase(TestCase): h = self.proxy_url self.assertEqual(h + '=foo?' + args_esc, pqu('=foo', st)) self.assertEqual(h + '=foo/bar?baz&' + args_esc, - pqu('=foo/bar?baz', st)) + pqu('=foo/bar?baz', st)) self.assertEqual(h + '=foo/bar?baz=quux&' + args_esc, - pqu('=foo/bar?baz=quux', st)) + pqu('=foo/bar?baz=quux', st)) self.assertEqual(h + '=foo/bar?mi=fa&so=la&' + args_esc, - pqu('=foo/bar?mi=fa&so=la', st)) + pqu('=foo/bar?mi=fa&so=la', st)) # With no service endpoint selection. args_esc = "_xrd_r=application%2Fxrds%2Bxml%3Bsep%3Dfalse" self.assertEqual(h + '=foo?' + args_esc, pqu('=foo', None)) - def test_proxy_url_qmarks(self): st = self.servicetype ste = self.servicetype_enc @@ -36,5 +34,4 @@ class ProxyQueryTestCase(TestCase): pqu = self.proxy.queryURL h = self.proxy_url self.assertEqual(h + '=foo/bar??' + args_esc, pqu('=foo/bar?', st)) - self.assertEqual(h + '=foo/bar????' + args_esc, - pqu('=foo/bar???', st)) + self.assertEqual(h + '=foo/bar????' + args_esc, pqu('=foo/bar???', st)) diff --git a/openid/test/test_yadis_discover.py b/openid/test/test_yadis_discover.py index 31464e3fd4070f51abe16c9947d1eff05a570a57..561bba89b3661523509c83b9dcb930ad305fe2c8 100644 --- a/openid/test/test_yadis_discover.py +++ b/openid/test/test_yadis_discover.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - """Tests for yadis.discover. @todo: Now that yadis.discover uses urljr.fetchers, we should be able to do @@ -55,8 +54,8 @@ class TestFetcher(object): try: data = discoverdata.generateSample(path, self.base_url) except KeyError: - return fetchers.HTTPResponse(status=404, final_url=current_url, - headers={}, body='') + return fetchers.HTTPResponse( + status=404, final_url=current_url, headers={}, body='') response = mkResponse(data) if response.status in [301, 302, 303, 307]: @@ -76,7 +75,7 @@ class TestSecondGet(unittest.TestCase): if self.count == 1: headers = { 'X-XRDS-Location'.lower(): 'http://unittest/404', - } + } return fetchers.HTTPResponse(uri, 200, headers, '') else: return fetchers.HTTPResponse(uri, 404) @@ -109,14 +108,11 @@ class _TestCase(unittest.TestCase): unittest.TestCase.__init__(self, methodName='runCustomTest') def setUp(self): - fetchers.setDefaultFetcher(TestFetcher(self.base_url), - wrap_exceptions=False) + fetchers.setDefaultFetcher( + TestFetcher(self.base_url), wrap_exceptions=False) self.input_url, self.expected = discoverdata.generateResult( - self.base_url, - self.input_name, - self.id_name, - self.result_name, + self.base_url, self.input_name, self.id_name, self.result_name, self.success) def tearDown(self): @@ -124,21 +120,20 @@ class _TestCase(unittest.TestCase): def runCustomTest(self): if self.expected is DiscoveryFailure: - self.assertRaises(DiscoveryFailure, - discover, self.input_url) + self.assertRaises(DiscoveryFailure, discover, self.input_url) else: result = discover(self.input_url) self.assertEqual(self.input_url, result.request_uri) msg = 'Identity URL mismatch: actual = %r, expected = %r' % ( result.normalized_uri, self.expected.normalized_uri) - self.assertEqual( - self.expected.normalized_uri, result.normalized_uri, msg) + self.assertEqual(self.expected.normalized_uri, + result.normalized_uri, msg) msg = 'Content mismatch: actual = %r, expected = %r' % ( result.response_text, self.expected.response_text) - self.assertEqual( - self.expected.response_text, result.response_text, msg) + self.assertEqual(self.expected.response_text, result.response_text, + msg) expected_keys = dir(self.expected) expected_keys.sort() @@ -161,9 +156,7 @@ class _TestCase(unittest.TestCase): except AttributeError: # run before setUp, or if setUp did not complete successfully. n = self.input_name - return "%s (%s)" % ( - n, - self.__class__.__module__) + return "%s (%s)" % (n, self.__class__.__module__) def pyUnitTests(): @@ -179,5 +172,6 @@ def test(): runner = unittest.TextTestRunner() return runner.run(pyUnitTests()) + if __name__ == '__main__': test() diff --git a/openid/test/trustroot.py b/openid/test/trustroot.py index 433dea0f1e12f0898c36f36fecde96e862a7cf64..3d39bfa29af4842cf051cd7cf93e08600828ef06 100644 --- a/openid/test/trustroot.py +++ b/openid/test/trustroot.py @@ -2,6 +2,7 @@ import os import unittest from openid.server.trustroot import TrustRoot + class _ParseTest(unittest.TestCase): def __init__(self, sanity, desc, case): unittest.TestCase.__init__(self) @@ -21,6 +22,7 @@ class _ParseTest(unittest.TestCase): else: assert tr is None, tr + class _MatchTest(unittest.TestCase): def __init__(self, match, desc, line): unittest.TestCase.__init__(self) @@ -43,6 +45,7 @@ class _MatchTest(unittest.TestCase): else: assert not match + def getTests(t, grps, head, dat): tests = [] top = head.strip() @@ -59,6 +62,7 @@ def getTests(t, grps, head, dat): i += 2 return tests + def parseTests(data): parts = list(map(str.strip, data.split('=' * 40 + '\n'))) assert not parts[0] diff --git a/openid/urinorm.py b/openid/urinorm.py index 90fc7da5f97980bd0f093cca5f3f81dfd0c0ffa4..485245c5e5def8214ca2a1a34c0b4644f87efe47 100644 --- a/openid/urinorm.py +++ b/openid/urinorm.py @@ -13,21 +13,22 @@ uri_re = re.compile(uri_pattern) # # unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" -uri_illegal_char_re = re.compile( - "[^-A-Za-z0-9:/?#[\]@!$&'()*+,;=._~%]", re.UNICODE) +uri_illegal_char_re = re.compile("[^-A-Za-z0-9:/?#[\]@!$&'()*+,;=._~%]", + re.UNICODE) authority_pattern = r'^([^@]*@)?([^:]*)(:.*)?' authority_re = re.compile(authority_pattern) - pct_encoded_pattern = r'%([0-9A-Fa-f]{2})' pct_encoded_re = re.compile(pct_encoded_pattern) - _unreserved = [False] * 256 -for _ in range(ord('A'), ord('Z') + 1): _unreserved[_] = True -for _ in range(ord('0'), ord('9') + 1): _unreserved[_] = True -for _ in range(ord('a'), ord('z') + 1): _unreserved[_] = True +for _ in range(ord('A'), ord('Z') + 1): + _unreserved[_] = True +for _ in range(ord('0'), ord('9') + 1): + _unreserved[_] = True +for _ in range(ord('a'), ord('z') + 1): + _unreserved[_] = True _unreserved[ord('-')] = True _unreserved[ord('.')] = True _unreserved[ord('_')] = True @@ -112,15 +113,15 @@ def urinorm(uri): scheme = scheme.lower() if scheme not in ('http', 'https'): - raise ValueError('Not an absolute HTTP or HTTPS URI: %r' % (uri,)) + raise ValueError('Not an absolute HTTP or HTTPS URI: %r' % (uri, )) authority = uri_mo.group(4) if authority is None: - raise ValueError('Not an absolute URI: %r' % (uri,)) + raise ValueError('Not an absolute URI: %r' % (uri, )) authority_mo = authority_re.match(authority) if authority_mo is None: - raise ValueError('URI does not have a valid authority: %r' % (uri,)) + raise ValueError('URI does not have a valid authority: %r' % (uri, )) userinfo, host, port = authority_mo.groups() @@ -135,8 +136,7 @@ def urinorm(uri): host = host.lower() if port: - if (port == ':' or - (scheme == 'http' and port == ':80') or + if (port == ':' or (scheme == 'http' and port == ':80') or (scheme == 'https' and port == ':443')): port = '' else: diff --git a/openid/yadis/__init__.py b/openid/yadis/__init__.py index eb920f444bfe74c4903a7f84375d317426091cf1..ef806e87508dac5c9701356b287e311b281af402 100644 --- a/openid/yadis/__init__.py +++ b/openid/yadis/__init__.py @@ -1,4 +1,4 @@ -#-*-coding: utf-8-*- +#-*- coding: utf-8 -*- __all__ = [ 'constants', diff --git a/openid/yadis/accept.py b/openid/yadis/accept.py index d750813106c429835677d0627bcc4a84ddd076a3..2f18aa662a7ab2751b4666a82b23d2b6b121379e 100644 --- a/openid/yadis/accept.py +++ b/openid/yadis/accept.py @@ -2,6 +2,7 @@ supporting server-directed content negotiation. """ + def generateAcceptHeader(*elements): """Generate an accept header value @@ -18,7 +19,7 @@ def generateAcceptHeader(*elements): if q > 1 or q <= 0: raise ValueError('Invalid preference factor: %r' % q) - qs = '%0.1f' % (q,) + qs = '%0.1f' % (q, ) parts.append((qs, mtype)) @@ -32,6 +33,7 @@ def generateAcceptHeader(*elements): return ', '.join(chunks) + def parseAcceptHeader(value): """Parse an accept header, ignoring any accept-extensions @@ -71,6 +73,7 @@ def parseAcceptHeader(value): accept.reverse() return [(main, sub, q) for (q, main, sub) in accept] + def matchTypes(accept_types, have_types): """Given the result of parsing an Accept: header, and the available MIME types, return the acceptable types with their @@ -118,6 +121,7 @@ def matchTypes(accept_types, have_types): accepted_list.sort() return [(mtype, q) for (_, _, q, mtype) in accepted_list] + def getAcceptable(accept_header, have_types): """Parse the accept header and return a list of available types in preferred order. If a type is unacceptable, it will not be in the diff --git a/openid/yadis/constants.py b/openid/yadis/constants.py index 75ff96eff9c1ba48205ce5c3a72df0809a397145..5d3072387ebda2f5c8f1b9d8bc63f27a80208af9 100644 --- a/openid/yadis/constants.py +++ b/openid/yadis/constants.py @@ -9,5 +9,4 @@ YADIS_CONTENT_TYPE = 'application/xrds+xml' YADIS_ACCEPT_HEADER = generateAcceptHeader( ('text/html', 0.3), ('application/xhtml+xml', 0.5), - (YADIS_CONTENT_TYPE, 1.0), - ) + (YADIS_CONTENT_TYPE, 1.0), ) diff --git a/openid/yadis/discover.py b/openid/yadis/discover.py index 21da85098ba376745b932a947e809117757d889f..af11b1036b17aa4ec2ec5fe3b4116151550658f0 100644 --- a/openid/yadis/discover.py +++ b/openid/yadis/discover.py @@ -9,6 +9,7 @@ from openid.yadis.constants import \ YADIS_HEADER_NAME, YADIS_CONTENT_TYPE, YADIS_ACCEPT_HEADER from openid.yadis.parsehtml import MetaNotFound, findHTMLMeta + class DiscoveryFailure(Exception): """Raised when a YADIS protocol error occurs in the discovery process""" identity_url = None @@ -17,6 +18,7 @@ class DiscoveryFailure(Exception): Exception.__init__(self, message) self.http_response = http_response + class DiscoveryResult(object): """Contains the result of performing Yadis discovery on a URI""" @@ -54,6 +56,7 @@ class DiscoveryResult(object): return (self.usedYadisLocation() or self.content_type == YADIS_CONTENT_TYPE) + def discover(uri): """Discover services for a given URI. @@ -73,7 +76,7 @@ def discover(uri): if resp.status not in (200, 206): raise DiscoveryFailure( 'HTTP Response status from identity URL host is not 200. ' - 'Got status %r' % (resp.status,), resp) + 'Got status %r' % (resp.status, ), resp) # Note the URL after following redirects result.normalized_uri = resp.final_url @@ -89,7 +92,7 @@ def discover(uri): if resp.status not in (200, 206): exc = DiscoveryFailure( 'HTTP Response status from Yadis host is not 200. ' - 'Got status %r' % (resp.status,), resp) + 'Got status %r' % (resp.status, ), resp) exc.identity_url = result.normalized_uri raise exc result.content_type = resp.headers.get('content-type') @@ -98,7 +101,6 @@ def discover(uri): return result - def whereIsYadis(resp): """Given a HTTPResponse, return the location of the Yadis document. @@ -116,7 +118,7 @@ def whereIsYadis(resp): # According to the spec, the content-type header must be an exact # match, or else we have to look for an indirection. if (content_type and - content_type.split(';', 1)[0].lower() == YADIS_CONTENT_TYPE): + content_type.split(';', 1)[0].lower() == YADIS_CONTENT_TYPE): return resp.final_url else: # Try the header diff --git a/openid/yadis/etxrd.py b/openid/yadis/etxrd.py index b42e2acf5c780452569534fe1a13e08efac08693..85e739afa9476d91a0db776d7565516008353614 100644 --- a/openid/yadis/etxrd.py +++ b/openid/yadis/etxrd.py @@ -16,7 +16,7 @@ __all__ = [ 'iterServices', 'expandService', 'expandServices', - ] +] import sys import random @@ -55,7 +55,13 @@ def parseXRDS(text): not contain an XRDS. """ try: - element = SafeElementTree.XML(text) + # lxml prefers to parse bytestrings, and occasionally chokes on a + # combination of text strings and declared XML encodings -- see + # https://github.com/necaris/python3-openid/issues/19 + # To avoid this, we ensure that the 'text' we're parsing is actually + # a bytestring + bytestring = text.encode('utf8') if isinstance(text, str) else text + element = SafeElementTree.XML(bytestring) except (SystemExit, MemoryError, AssertionError, ImportError): raise except Exception as why: @@ -95,6 +101,7 @@ def mkXRDSTag(t): """ return nsTag(XRDS_NS, t) + # Tags that are used in Yadis documents root_tag = mkXRDSTag('XRDS') service_tag = mkXRDTag('Service') @@ -200,12 +207,14 @@ class _Max(object): Should only be used as a singleton. Implemented for use as a priority value for when a priority is not specified. """ + def __lt__(self, other): return isinstance(other, self.__class__) def __eq__(self, other): return isinstance(other, self.__class__) + Max = _Max() @@ -260,15 +269,18 @@ def iterServices(xrd_tree): def sortedURIs(service_element): """Given a Service element, return a list of the contents of all URI tags in priority order.""" - return [uri_element.text for uri_element - in prioSort(service_element.findall(uri_tag))] + return [ + uri_element.text + for uri_element in prioSort(service_element.findall(uri_tag)) + ] def getTypeURIs(service_element): """Given a Service element, return a list of the contents of all Type tags""" - return [type_element.text for type_element - in service_element.findall(type_tag)] + return [ + type_element.text for type_element in service_element.findall(type_tag) + ] def expandService(service_element): diff --git a/openid/yadis/filters.py b/openid/yadis/filters.py index f195f4ed8b5463651f5d84e4722041dd75cd211a..4a7fbc675bf85692dcf28208ffa81d6847dd5a64 100644 --- a/openid/yadis/filters.py +++ b/openid/yadis/filters.py @@ -9,7 +9,7 @@ __all__ = [ 'IFilter', 'TransformFilterMaker', 'CompoundFilter', - ] +] from openid.yadis.etxrd import expandService import collections @@ -28,6 +28,7 @@ class BasicServiceEndpoint(object): The simplest kind of filter you can write implements fromBasicServiceEndpoint, which takes one of these objects. """ + def __init__(self, yadis_url, type_uris, uri, service_element): self.type_uris = type_uris self.yadis_url = yadis_url @@ -106,8 +107,8 @@ class TransformFilterMaker(object): # Create a basic endpoint object to represent this # yadis_url, Service, Type, URI combination - endpoint = BasicServiceEndpoint( - yadis_url, type_uris, uri, service_element) + endpoint = BasicServiceEndpoint(yadis_url, type_uris, uri, + service_element) e = self.applyFilters(endpoint) if e is not None: @@ -132,6 +133,7 @@ class CompoundFilter(object): """Create a new filter that applies a set of filters to an endpoint and collects their results. """ + def __init__(self, subfilters): self.subfilters = subfilters @@ -144,6 +146,7 @@ class CompoundFilter(object): subfilter.getServiceEndpoints(yadis_url, service_element)) return endpoints + # Exception raised when something is not able to be turned into a filter filter_type_error = TypeError( 'Expected a filter, an endpoint, a callable or a list of any of these.') diff --git a/openid/yadis/manager.py b/openid/yadis/manager.py index 4a698089534c2962a0494642cf40091685dd7ac1..9c9a042f6dd15c66249ffdc78c2a0406dda011c8 100644 --- a/openid/yadis/manager.py +++ b/openid/yadis/manager.py @@ -54,6 +54,7 @@ class YadisServiceManager(object): """Store this object in the session, by its session key.""" session[self.session_key] = self + class Discovery(object): """State management for discovery. diff --git a/openid/yadis/services.py b/openid/yadis/services.py index 5d0f9fb87919e1a7723e0e3161ddc626d896a933..2092a9cae77db1d4d4a21e0fa9b878ba209e605e 100644 --- a/openid/yadis/services.py +++ b/openid/yadis/services.py @@ -4,6 +4,7 @@ from openid.yadis.filters import mkFilter from openid.yadis.discover import discover, DiscoveryFailure from openid.yadis.etxrd import parseXRDS, iterServices, XRDSError + def getServiceEndpoints(input_url, flt=None): """Perform the Yadis protocol on the input URL and return an iterable of resulting endpoint objects. @@ -24,12 +25,13 @@ def getServiceEndpoints(input_url, flt=None): """ result = discover(input_url) try: - endpoints = applyFilter(result.normalized_uri, - result.response_text, flt) + endpoints = applyFilter(result.normalized_uri, result.response_text, + flt) except XRDSError as err: raise DiscoveryFailure(str(err), None) return (result.normalized_uri, endpoints) + def applyFilter(normalized_uri, xrd_data, flt=None): """Generate an iterable of endpoint objects given this input data, presumably from the result of performing the Yadis protocol. diff --git a/openid/yadis/xri.py b/openid/yadis/xri.py index a57d26f78bac7da8228f5c8dbca2c473f8c6d9d0..a88027f700727d66867011e4971ae8489b15f6e5 100644 --- a/openid/yadis/xri.py +++ b/openid/yadis/xri.py @@ -17,8 +17,8 @@ def identifierScheme(identifier): @returns: C{"XRI"} or C{"URI"} """ - if identifier.startswith('xri://') or ( - identifier and identifier[0] in XRI_AUTHORITIES): + if identifier.startswith('xri://') or (identifier and + identifier[0] in XRI_AUTHORITIES): return "XRI" else: return "URI" @@ -101,8 +101,7 @@ def rootAuthority(xri): else: # IRI reference. XXX: Can IRI authorities have segments? segments = authority.split('!') - segments = reduce(list.__add__, - [s.split('*') for s in segments]) + segments = reduce(list.__add__, [s.split('*') for s in segments]) root = segments[0] return XRI(root) diff --git a/openid/yadis/xrires.py b/openid/yadis/xrires.py index d012232e0ca83892dd4f9f9575079bf8858329bf..a7140294f09b0bbd5d88690c2a8503fa78397447 100644 --- a/openid/yadis/xrires.py +++ b/openid/yadis/xrires.py @@ -10,13 +10,14 @@ from openid.yadis.services import iterServices DEFAULT_PROXY = 'http://proxy.xri.net/' + class ProxyResolver(object): """Python interface to a remote XRI proxy resolver. """ + def __init__(self, proxy_url=DEFAULT_PROXY): self.proxy_url = proxy_url - def queryURL(self, xri, service_type=None): """Build a URL to query the proxy resolver. @@ -41,7 +42,7 @@ class ProxyResolver(object): # 11:13:42), then we could ask for application/xrd+xml instead, # which would give us a bit less to process. '_xrd_r': 'application/xrds+xml', - } + } if service_type: args['_xrd_t'] = service_type else: @@ -50,7 +51,6 @@ class ProxyResolver(object): query = _appendArgs(hxri, args) return query - def query(self, xri, service_types): """Resolve some services for an XRI. diff --git a/python3_openid.egg-info/PKG-INFO b/python3_openid.egg-info/PKG-INFO index e7a66b80723a0bf4745cad7baf86a9820d148fd0..c540d98259439fb7d977cb17f54f08c631df9c40 100644 --- a/python3_openid.egg-info/PKG-INFO +++ b/python3_openid.egg-info/PKG-INFO @@ -1,12 +1,12 @@ Metadata-Version: 1.1 Name: python3-openid -Version: 3.0.9 +Version: 3.1.0 Summary: OpenID support for modern servers and consumers. Home-page: http://github.com/necaris/python3-openid Author: Rami Chowdhury Author-email: rami.chowdhury@gmail.com License: UNKNOWN -Download-URL: http://github.com/necaris/python3-openid/tarball/v3.0.9 +Download-URL: http://github.com/necaris/python3-openid/tarball/v3.1.0 Description: This is a set of Python packages to support use of the OpenID decentralized identity system in your application, update to Python 3. Want to enable single sign-on for your web site? Use the openid.consumer diff --git a/python3_openid.egg-info/SOURCES.txt b/python3_openid.egg-info/SOURCES.txt index db6fea429213a05a19ebda1c916505f2af8fa72b..de53901724582304e1396cd3ebf2cc28049938f4 100644 --- a/python3_openid.egg-info/SOURCES.txt +++ b/python3_openid.egg-info/SOURCES.txt @@ -4,15 +4,13 @@ NEWS.md background-associations.txt setup.cfg setup.py -admin/builddiscover.py -admin/fixperms -admin/gettlds.py -admin/runtests -admin/setversion +admin/build_discover_data.py +admin/get_tlds.py +admin/next_version.py +admin/patch_version.py contrib/associate contrib/openid-parse contrib/upgrade-store-1.1-to-2.0 -contrib/upgrade-store-1.1-to-2.0~ examples/__init__.py examples/consumer.py examples/discover diff --git a/setup.cfg b/setup.cfg index bf52e6a907b11e9a1ae936764188ce3c82ec3e1a..50c20756e7700c7ee1937c1ecbf976e6289d519f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,9 +1,12 @@ [sdist] force_manifest = 1 -formats = gztar,zip +formats = gztar, zip + +[bdist_wheel] +universal = 0 [egg_info] tag_build = -tag_svn_revision = 0 tag_date = 0 +tag_svn_revision = 0 diff --git a/setup.py b/setup.py index dcc219cee93babf97dea776e7d10aec45fbdbf18..02145e0ee048cde46bd1b3c0814fe987870b7981 100644 --- a/setup.py +++ b/setup.py @@ -1,15 +1,15 @@ import sys -import os - from setuptools import setup import openid version = openid.__version__ -if 'sdist' in sys.argv: - # When building a source distribution, generate documentation - os.system('./admin/makedoc') +install_requires = [ + # Ensure that Python <= 3.3 uses an older version of `defusedxml`, which + # dropped compatibility in 0.5.0 + 'defusedxml' + ('<=0.4.1' if sys.version_info < (3, 4) else ''), +] setup( name='python3-openid', @@ -37,9 +37,7 @@ Includes example code and support for a variety of storage back-ends.''', maintainer_email='rami.chowdhury@gmail.com', download_url=('http://github.com/necaris/python3-openid/tarball' '/v{}'.format(version)), - install_requires=[ - 'defusedxml', - ], + install_requires=install_requires, classifiers=[ "Development Status :: 5 - Production/Stable", "Environment :: Web Environment", @@ -54,5 +52,4 @@ Includes example code and support for a variety of storage back-ends.''', "Topic :: Software Development :: Libraries :: Python Modules", ("Topic :: System :: Systems Administration :: " "Authentication/Directory"), - ] -) + ])