aboutsummaryrefslogtreecommitdiffstats
path: root/lib/python2.7/site-packages/SQLAlchemy-0.7.0-py2.7-linux-x86_64.egg/sqlalchemy/engine/ddl.py
blob: 79958baae4065054c6d080adcefe41446fb47862 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# engine/ddl.py
# Copyright (C) 2009-2011 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Routines to handle CREATE/DROP workflow."""

from sqlalchemy import engine, schema
from sqlalchemy.sql import util as sql_util


class DDLBase(schema.SchemaVisitor):
    """Common base for the CREATE/DROP schema visitors in this module.

    It only records the connection; the subclasses execute their DDL
    constructs through ``self.connection.execute(...)``.
    """

    def __init__(self, connection):
        # Connection (or connectable) that every emitted DDL statement
        # is executed against.
        self.connection = connection

class SchemaGenerator(DDLBase):
    """Visitor that emits CREATE DDL for metadata, tables, sequences, indexes.

    Statements (``CreateTable``, ``CreateSequence``, ``CreateIndex``) are
    executed on ``self.connection``.  When ``checkfirst`` is true, objects
    that ``dialect.has_table`` / ``dialect.has_sequence`` report as already
    present are skipped.  ``tables``, if given and non-empty, restricts
    ``visit_metadata`` to that subset of tables.
    """

    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
        super(SchemaGenerator, self).__init__(connection, **kwargs)
        self.checkfirst = checkfirst
        # Store the table filter as a non-empty set, or None meaning
        # "use every table of the MetaData being visited".
        selected = set(tables) if tables else set()
        self.tables = selected or None
        self.preparer = dialect.identifier_preparer
        self.dialect = dialect

    def _can_create_table(self, table):
        """Return True if CREATE TABLE should be emitted for ``table``.

        The table name (and schema name, when set) are passed to
        ``dialect.validate_identifier`` first.  Without ``checkfirst``
        the answer is always True; with it, only tables the dialect does
        not report as existing qualify.
        """
        dialect = self.dialect
        dialect.validate_identifier(table.name)
        if table.schema:
            dialect.validate_identifier(table.schema)
        if not self.checkfirst:
            return True
        already_there = dialect.has_table(self.connection,
                                          table.name, schema=table.schema)
        return not already_there

    def _can_create_sequence(self, sequence):
        """Return True if CREATE SEQUENCE should be emitted for ``sequence``.

        False when the dialect lacks sequence support, or when the dialect
        sets ``sequences_optional`` and the sequence is marked ``optional``.
        Otherwise True, unless ``checkfirst`` is set and the dialect
        reports the sequence as existing.
        """
        dialect = self.dialect
        if not dialect.supports_sequences:
            return False
        if dialect.sequences_optional and sequence.optional:
            return False
        if not self.checkfirst:
            return True
        return not dialect.has_sequence(self.connection,
                                        sequence.name,
                                        schema=sequence.schema)

    def visit_metadata(self, metadata):
        """CREATE every eligible standalone sequence and table of ``metadata``.

        Sequences not bound to a column are created first.  Tables follow
        in ``sql_util.sort_tables`` order.  ``before_create`` /
        ``after_create`` events fire on ``metadata`` around the whole run.
        """
        candidates = self.tables or metadata.tables.values()
        creatable_tables = list(filter(self._can_create_table,
                                       sql_util.sort_tables(candidates)))
        creatable_seqs = [seq for seq in metadata._sequences.values()
                          if seq.column is None
                          and self._can_create_sequence(seq)]

        metadata.dispatch.before_create(metadata, self.connection,
                                        tables=creatable_tables,
                                        checkfirst=self.checkfirst)

        for seq in creatable_seqs:
            self.traverse_single(seq, create_ok=True)

        for tbl in creatable_tables:
            self.traverse_single(tbl, create_ok=True)

        metadata.dispatch.after_create(metadata, self.connection,
                                       tables=creatable_tables,
                                       checkfirst=self.checkfirst)

    def visit_table(self, table, create_ok=False):
        """CREATE ``table``, its column defaults and its indexes.

        ``create_ok=True`` means the caller already ran the eligibility
        check, so ``_can_create_table`` is not consulted again.
        """
        if not (create_ok or self._can_create_table(table)):
            return

        table.dispatch.before_create(table, self.connection,
                                     checkfirst=self.checkfirst)

        # Column defaults are traversed before CREATE TABLE; a Sequence
        # default is presumably routed to visit_sequence by traverse_single.
        for dflt in (col.default for col in table.columns
                     if col.default is not None):
            self.traverse_single(dflt)

        self.connection.execute(schema.CreateTable(table))

        if hasattr(table, 'indexes'):
            for idx in table.indexes:
                self.traverse_single(idx)

        table.dispatch.after_create(table, self.connection,
                                    checkfirst=self.checkfirst)

    def visit_sequence(self, sequence, create_ok=False):
        """Emit CREATE SEQUENCE unless the sequence is ineligible."""
        if create_ok or self._can_create_sequence(sequence):
            self.connection.execute(schema.CreateSequence(sequence))

    def visit_index(self, index):
        """Emit CREATE INDEX for ``index`` unconditionally."""
        self.connection.execute(schema.CreateIndex(index))


class SchemaDropper(DDLBase):
    """Visitor that emits DROP DDL for metadata, tables, sequences, indexes.

    Statements (``DropTable``, ``DropSequence``, ``DropIndex``) are executed
    on ``self.connection``.  When ``checkfirst`` is true, only objects that
    ``dialect.has_table`` / ``dialect.has_sequence`` report as existing
    are dropped.  ``tables``, if truthy, restricts ``visit_metadata`` to
    those tables.
    """

    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
        super(SchemaDropper, self).__init__(connection, **kwargs)
        self.checkfirst = checkfirst
        # Optional table filter, stored as given; a falsy value means
        # "every table of the MetaData being visited".
        self.tables = tables
        self.preparer = dialect.identifier_preparer
        self.dialect = dialect

    def visit_metadata(self, metadata):
        """DROP every eligible table and standalone sequence of ``metadata``.

        Tables are dropped in reverse ``sql_util.sort_tables`` order.
        Sequences that are not bound to a column are dropped afterwards.
        ``before_drop`` / ``after_drop`` events fire on ``metadata`` around
        the whole run.
        """
        if self.tables:
            tables = self.tables
        else:
            tables = metadata.tables.values()
        collection = [t for t in reversed(sql_util.sort_tables(tables))
                      if self._can_drop_table(t)]
        seq_coll = [s for s in metadata._sequences.values()
                    if s.column is None and self._can_drop_sequence(s)]

        metadata.dispatch.before_drop(metadata, self.connection,
                                      tables=collection,
                                      checkfirst=self.checkfirst)

        for table in collection:
            self.traverse_single(table, drop_ok=True)

        for seq in seq_coll:
            self.traverse_single(seq, drop_ok=True)

        metadata.dispatch.after_drop(metadata, self.connection,
                                     tables=collection,
                                     checkfirst=self.checkfirst)

    def _can_drop_table(self, table):
        """Return True if DROP TABLE should be emitted for ``table``.

        The table name (and schema name, when set) are validated by the
        dialect first.  Without ``checkfirst`` the answer is always True;
        with it, only tables the dialect reports as existing qualify.
        """
        self.dialect.validate_identifier(table.name)
        if table.schema:
            self.dialect.validate_identifier(table.schema)
        return not self.checkfirst or self.dialect.has_table(
            self.connection, table.name, schema=table.schema)

    def _can_drop_sequence(self, sequence):
        """Return True if DROP SEQUENCE should be emitted for ``sequence``.

        False when the dialect lacks sequence support, or when the dialect
        sets ``sequences_optional`` and the sequence is marked ``optional``.
        With ``checkfirst``, the sequence must also be reported as existing.
        """
        return self.dialect.supports_sequences and (
            (not self.dialect.sequences_optional or not sequence.optional)
            and (not self.checkfirst or
                 self.dialect.has_sequence(self.connection,
                                           sequence.name,
                                           schema=sequence.schema))
        )

    def visit_index(self, index):
        """Emit DROP INDEX for ``index`` unconditionally."""
        self.connection.execute(schema.DropIndex(index))

    def visit_table(self, table, drop_ok=False):
        """DROP ``table``, then its column-level defaults (e.g. sequences).

        ``drop_ok=True`` means the caller already ran the eligibility
        check, so ``_can_drop_table`` is not consulted again.
        """
        if not drop_ok and not self._can_drop_table(table):
            return

        table.dispatch.before_drop(table, self.connection,
                                   checkfirst=self.checkfirst)

        self.connection.execute(schema.DropTable(table))

        # Column defaults (e.g. a Sequence) are traversed only AFTER the
        # table is gone.  Creation order is sequence-then-table
        # (SchemaGenerator.visit_table), so dropping must be the reverse.
        # Dropping the sequence first can fail if the table still refers
        # to it, for instance through a server-side default.
        for column in table.columns:
            if column.default is not None:
                self.traverse_single(column.default)

        table.dispatch.after_drop(table, self.connection,
                                  checkfirst=self.checkfirst)

    def visit_sequence(self, sequence, drop_ok=False):
        """Emit DROP SEQUENCE unless the sequence is ineligible."""
        if not drop_ok and not self._can_drop_sequence(sequence):
            return
        self.connection.execute(schema.DropSequence(sequence))