aboutsummaryrefslogtreecommitdiffstats
path: root/lib/python2.7/site-packages/SQLAlchemy-0.7.0-py2.7-linux-x86_64.egg/sqlalchemy/util/compat.py
blob: 0fb004500878517a7fe5bb73e3a762b3d754c9fb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# util/compat.py
# Copyright (C) 2005-2011 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Handle Python version/platform incompatibilities."""

import sys

# Py2K
import __builtin__
# end Py2K

try:
    import threading
except ImportError:
    import dummy_threading as threading

# Interpreter / platform flags, computed once when the module is imported.
win32 = sys.platform.startswith('win')
jython = sys.platform.startswith('java')
pypy = hasattr(sys, 'pypy_version_info')
# py3k is also switched on when the interpreter exposes a true
# ``sys.py3kwarning`` attribute.
py3k = sys.version_info >= (3, 0) or getattr(sys, 'py3kwarning', False)
py32 = sys.version_info >= (3, 2)

# set_types: the type (or tuple of types) to pass to isinstance() when
# checking whether a value is a set.
if py3k:
    set_types = set
else:
    if sys.version_info >= (2, 6):
        # 2.6 deprecates sets.Set, but it still has to be detectable in
        # user code and in values returned by DB-APIs.  A temporary
        # "ignore" entry is placed at the head of the warning filters
        # around the import and removed again afterwards.
        import warnings
        ignore = ('ignore', None, DeprecationWarning, None, 0)
        try:
            warnings.filters.insert(0, ignore)
        except Exception:
            # the filter could not be installed; import regardless
            import sets
        else:
            import sets
            warnings.filters.remove(ignore)
    else:
        import sets

    set_types = (set, sets.Set)

if py3k:
    import pickle
else:
    try:
        import cPickle as pickle
    except ImportError:
        import pickle

# a controversial feature, required by MySQLdb currently
def buffer(x):
    """Return ``x`` unchanged.

    Identity stand-in bound to the name ``buffer``; the statement below
    replaces it with ``__builtin__.buffer`` whenever that attribute exists.
    """
    return x 

# NOTE(review): the Py2K / end Py2K marker comments look like directives for
# a source-conversion step -- keep them exactly as written; TODO confirm.
# Py2K
buffer = getattr(__builtin__, 'buffer', buffer)
# end Py2K

# Use the standard-library update_wrapper when functools provides it;
# otherwise define a small equivalent.
try:
    from functools import update_wrapper
except ImportError:
    def update_wrapper(wrapper, wrapped,
                       assigned=('__doc__', '__module__', '__name__'),
                       updated=('__dict__',)):
        """Make ``wrapper`` look like ``wrapped`` and return ``wrapper``.

        Each attribute named in ``assigned`` is copied from ``wrapped``
        onto ``wrapper``; each attribute named in ``updated`` is fetched
        from ``wrapper`` and ``update()``-d with the corresponding
        attribute of ``wrapped`` (nothing is merged if ``wrapped`` lacks
        it).
        """
        for name in assigned:
            value = getattr(wrapped, name)
            setattr(wrapper, name, value)
        for name in updated:
            target = getattr(wrapper, name)
            target.update(getattr(wrapped, name, ()))
        return wrapper

# Use functools.partial when available; otherwise emulate it with a closure.
try:
    from functools import partial
except ImportError:
    def partial(func, *args, **keywords):
        """Return a callable invoking ``func`` with pre-bound arguments.

        Positional arguments given at call time are appended after
        ``args``; keyword arguments given at call time override the
        entries of ``keywords``.
        """
        def newfunc(*fargs, **fkeywords):
            # merge into a fresh dict so the bound keywords stay untouched
            merged = keywords.copy()
            merged.update(fkeywords)
            positional = args + fargs
            return func(*positional, **merged)
        return newfunc


if not py3k:
    # the builtins are available; re-export them under the same names
    reduce = __builtin__.reduce
    cmp = __builtin__.cmp
    callable = __builtin__.callable
else:
    from functools import reduce

    # they're bringing it back in 3.2.  brilliant !
    def callable(fn):
        """Return True if ``fn`` has a ``__call__`` attribute."""
        return hasattr(fn, '__call__')

    def cmp(a, b):
        """Three-way compare: 1 if a > b, -1 if a < b, otherwise 0."""
        return (a > b) - (a < b)

try:
    from collections import defaultdict
except ImportError:
    class defaultdict(dict):
        """Pure-Python stand-in for ``collections.defaultdict``.

        Defined only when ``collections.defaultdict`` cannot be imported.
        Looking up a missing key calls ``default_factory`` with no
        arguments, stores the result under that key and returns it.
        """

        def __init__(self, default_factory=None, *a, **kw):
            """Create the mapping.

            ``default_factory`` must be ``None`` or have a ``__call__``
            attribute, otherwise ``TypeError`` is raised.  Remaining
            arguments are passed to ``dict.__init__``.
            """
            if (default_factory is not None and
                not hasattr(default_factory, '__call__')):
                raise TypeError('first argument must be callable')
            dict.__init__(self, *a, **kw)
            self.default_factory = default_factory

        def __getitem__(self, key):
            """Return the value for ``key``, via ``__missing__`` if absent."""
            try:
                return dict.__getitem__(self, key)
            except KeyError:
                return self.__missing__(key)

        def __missing__(self, key):
            """Create, store and return the default value for ``key``.

            Raises ``KeyError`` when no ``default_factory`` is set.
            """
            if self.default_factory is None:
                raise KeyError(key)
            self[key] = value = self.default_factory()
            return value

        def __reduce__(self):
            """Pickle support: rebuild from the factory plus the items."""
            if self.default_factory is None:
                args = tuple()
            else:
                args = self.default_factory,
            return type(self), args, None, None, self.iteritems()

        def copy(self):
            """Return a shallow copy (same as ``__copy__``)."""
            return self.__copy__()

        def __copy__(self):
            """Return a shallow copy with the same ``default_factory``."""
            return type(self)(self.default_factory, self)

        def __deepcopy__(self, memo):
            """Return a deep copy with the same ``default_factory``.

            ``memo`` is passed on to ``copy.deepcopy`` so that an object
            referenced both from this mapping and from elsewhere in the
            structure being copied is copied only once.
            """
            import copy
            return type(self)(self.default_factory,
                              copy.deepcopy(self.items(), memo))

        def __repr__(self):
            return 'defaultdict(%s, %s)' % (self.default_factory,
                                            dict.__repr__(self))


# py25_dict: a dict type whose subclasses may define __missing__.
# The built-in dict is used when it already honors the hook; otherwise a
# small subclass emulating it is defined.
class _probe(dict):
    """Throwaway subclass used only to detect ``__missing__`` support."""
    def __missing__(self, key):
        return 1

try:
    try:
        # raises KeyError when dict does not call __missing__
        _probe()['missing']
        py25_dict = dict
    except KeyError:
        class py25_dict(dict):
            """dict subclass that calls ``__missing__`` for absent keys."""
            def __getitem__(self, key):
                try:
                    return dict.__getitem__(self, key)
                except KeyError:
                    try:
                        on_missing = self.__missing__
                    except AttributeError:
                        # no hook defined: behave like a plain dict
                        raise KeyError(key)
                    return on_missing(key)
finally:
    # the probe class is only needed for the check above
    del _probe


# Select the MD5 constructor: hashlib if importable, else the md5 module.
try:
    import hashlib
    _md5 = hashlib.md5
except ImportError:
    import md5
    _md5 = md5.new

def md5_hex(x):
    """Return the MD5 digest of ``x`` as a hexadecimal string."""
    # Py3K
    #x = x.encode('utf-8')
    hasher = _md5()
    hasher.update(x)
    return hasher.hexdigest()

import time
# time_func: the timer callable chosen for the current platform.
if win32 or jython:
    # time.clock was removed in Python 3.8.  Keep using it wherever it
    # still exists so existing behavior is unchanged, and fall back to
    # time.perf_counter only when it is absent.
    time_func = getattr(time, 'clock', None) or time.perf_counter
else:
    time_func = time.time

if sys.version_info >= (2, 5):
    def decode_slice(slc):
        """Return ``(start, stop, step)`` for a slice sent to __getitem__.

        Any member that provides the ``__index__()`` method (Python 2.5)
        is replaced by the result of calling it; all other members,
        including ``None``, are returned unchanged.

        """
        def _resolve(bound):
            if hasattr(bound, '__index__'):
                return bound.__index__()
            return bound
        return tuple([_resolve(b) for b in (slc.start, slc.stop, slc.step)])
else:
    def decode_slice(slc):
        """Return ``(start, stop, step)`` for a slice, unmodified."""
        return (slc.start, slc.stop, slc.step)

# dottedgetter(attr) returns a callable that fetches a (possibly dotted)
# attribute path from its argument.
if sys.version_info >= (2, 6):
    from operator import attrgetter as dottedgetter
else:
    def dottedgetter(attr):
        """Return a getter that walks the dot-separated path ``attr``."""
        def g(obj):
            current = obj
            for part in attr.split("."):
                current = getattr(current, part)
            return current
        return g


import decimal