/[thuban]/branches/WIP-pyshapelib-bramz/Thuban/Model/transientdb.py
ViewVC logotype

Diff of /branches/WIP-pyshapelib-bramz/Thuban/Model/transientdb.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 765 by bh, Tue Apr 29 12:42:14 2003 UTC revision 844 by bh, Tue May 6 18:05:04 2003 UTC
# Line 53  class TransientDatabase: Line 53  class TransientDatabase:
53          self.close()          self.close()
54    
55      def close(self):      def close(self):
56          if self.conn:          if self.conn is not None:
57              self.conn.close()              self.conn.close()
58                self.conn = None
59    
60      def new_table_name(self):      def new_table_name(self):
61          self.num_tables += 1          self.num_tables += 1
# Line 84  class ColumnReference: Line 85  class ColumnReference:
85          self.internal_name = internal_name          self.internal_name = internal_name
86    
87    
88  class TransientTableBase:  class TransientTableBase(table.OldTableInterfaceMixin):
89    
90      """Base class for tables in the transient database"""      """Base class for tables in the transient database"""
91    
# Line 95  class TransientTableBase: Line 96  class TransientTableBase:
96          self.indexed_columns = {}          self.indexed_columns = {}
97          self.read_record_cursor = None          self.read_record_cursor = None
98          self.read_record_last_row = None          self.read_record_last_row = None
99            self.read_record_last_result = None
100    
101      def create(self, columns):      def create(self, columns):
102          self.columns = columns          self.columns = columns
# Line 102  class TransientTableBase: Line 104  class TransientTableBase:
104          self.orig_names = []          self.orig_names = []
105          self.internal_to_orig = {}          self.internal_to_orig = {}
106          self.orig_to_internal = {}          self.orig_to_internal = {}
107            self.column_map = {}
108    
109          # Create the column objects and fill various maps and lists          # Create the column objects and fill various maps and lists
110          for col in self.columns:          for index in range(len(self.columns)):
111                col = self.columns[index]
112              self.name_to_column[col.name] = col              self.name_to_column[col.name] = col
113              self.orig_names.append(col.name)              self.orig_names.append(col.name)
114              self.internal_to_orig[col.internal_name] = col.name              self.internal_to_orig[col.internal_name] = col.name
115              self.orig_to_internal[col.name] = col.internal_name              self.orig_to_internal[col.name] = col.internal_name
116                self.column_map[col.name] = col
117                self.column_map[index] = col
118    
119          # Build the CREATE TABLE statement and create the table in the          # Build the CREATE TABLE statement and create the table in the
120          # database          # database
121          table_types = []          table_types = ["id INTEGER PRIMARY KEY"]
122          for col in self.columns:          for col in self.columns:
123              table_types.append("%s %s" % (col.internal_name,              table_types.append("%s %s" % (col.internal_name,
124                                            sql_type_map[col.type]))                                            sql_type_map[col.type]))
# Line 136  class TransientTableBase: Line 142  class TransientTableBase:
142              self.db.execute(stmt)              self.db.execute(stmt)
143              self.indexed_columns[column] = 1              self.indexed_columns[column] = 1
144    
145      def field_count(self):      def NumColumns(self):
146          return len(self.columns)          return len(self.columns)
147    
148      def field_info(self, i):      def NumRows(self):
         col = self.columns[i]  
         return col.type, col.name, 0, 0  
   
     def field_info_by_name(self, name):  
         for col in self.columns:  
             if col.name == name:  
                 return col.type, col.name, 0, 0  
         else:  
             return None  
   
     def record_count(self):  
149          result = self.db.execute("SELECT count(*) FROM %s;" % self.tablename)          result = self.db.execute("SELECT count(*) FROM %s;" % self.tablename)
150          return int(result[0])          return int(result[0])
151    
152      def read_record(self, index):      def Columns(self):
153            return self.columns
154    
155        def Column(self, col):
156            return self.column_map[col]
157    
158        def HasColumn(self, col):
159            """Return whether the table has a column with the given name or index
160            """
161            return self.column_map.has_key(col)
162    
163        def ReadRowAsDict(self, index):
164          if self.read_record_cursor is None or index <self.read_record_last_row:          if self.read_record_cursor is None or index <self.read_record_last_row:
165              stmt = ("SELECT %s FROM %s;"              stmt = ("SELECT %s FROM %s;"
166                      % (", ".join([c.internal_name for c in self.columns]),                      % (", ".join([c.internal_name for c in self.columns]),
# Line 162  class TransientTableBase: Line 168  class TransientTableBase:
168              self.read_record_cursor = self.db.cursor()              self.read_record_cursor = self.db.cursor()
169              self.read_record_cursor.execute(stmt)              self.read_record_cursor.execute(stmt)
170              self.read_record_last_row = -1              self.read_record_last_row = -1
171          for i in range(index - self.read_record_last_row):              self.read_record_last_result = None
172              result = self.read_record_cursor.fetchone()  
173            # Now we should have a cursor at a position less than or equal
174            # to the index so the following if statement will always set
175            # result to a suitable value
176            assert index >= self.read_record_last_row
177    
178            if index == self.read_record_last_row:
179                result = self.read_record_last_result
180            else:
181                for i in range(index - self.read_record_last_row):
182                    result = self.read_record_cursor.fetchone()
183                    self.read_record_last_result = result
184          self.read_record_last_row = index          self.read_record_last_row = index
185          result = dict(zip(self.orig_names, result))          return dict(zip(self.orig_names, result))
         return result  
186    
187      def field_range(self, colname):      def ValueRange(self, col):
188          col = self.name_to_column[colname]          col = self.column_map[col]
189          iname = col.internal_name          iname = col.internal_name
190          min, max = self.db.execute("SELECT min(%s), max(%s) FROM %s;"          min, max = self.db.execute("SELECT min(%s), max(%s) FROM %s;"
191                                     % (iname, iname, self.tablename))                                     % (iname, iname, self.tablename))
192          converter = type_converter_map[col.type]          converter = type_converter_map[col.type]
193          return ((converter(min), None), (converter(max), None))          return (converter(min), converter(max))
194    
195      def GetUniqueValues(self, colname):      def UniqueValues(self, col):
196          iname = self.orig_to_internal[colname]          iname = self.column_map[col].internal_name
197          cursor = self.db.cursor()          cursor = self.db.cursor()
198          cursor.execute("SELECT %s FROM %s GROUP BY %s;"          cursor.execute("SELECT %s FROM %s GROUP BY %s;"
199                         % (iname, self.tablename, iname))                         % (iname, self.tablename, iname))
# Line 189  class TransientTableBase: Line 205  class TransientTableBase:
205              result.append(row[0])              result.append(row[0])
206          return result          return result
207    
208        def SimpleQuery(self, left, comparison, right):
209            """Return the indices of all rows that matching a condition.
210    
211            Parameters:
212               left -- The column object for the left side of the comparison
213    
214               comparison -- The comparison operator as a string. It must be
215                             one of '==', '!=', '<', '<=', '>=', '>'
216    
217               right -- The right hand side of the comparison. It must be
218                        either a column object or a value, i.e. a string,
219                        int or float.
220    
221            The return value is a sorted list of the indices of the rows
222            where the condition is true.
223            """
224            if comparison not in ("==", "!=", "<", "<=", ">=", ">"):
225                raise ValueError("Comparison operator %r not allowed" % comparison)
226    
227            if hasattr(right, "internal_name"):
228                right_template = right.internal_name
229                params = ()
230            else:
231                right_template = "%s"
232                params = (right,)
233    
234            query = "SELECT id FROM %s WHERE %s %s %s ORDER BY id;" \
235                    % (self.tablename, left.internal_name, comparison,
236                       right_template)
237    
238            cursor = self.db.cursor()
239            cursor.execute(query, params)
240            result = []
241            while 1:
242                row = cursor.fetchone()
243                if row is None:
244                    break
245                result.append(row[0])
246            return result
247    
248    
249  class TransientTable(TransientTableBase):  class TransientTable(TransientTableBase):
250    
# Line 205  class TransientTable(TransientTableBase) Line 261  class TransientTable(TransientTableBase)
261    
262      def create(self, table):      def create(self, table):
263          columns = []          columns = []
264          for i in range(table.field_count()):          for col in table.Columns():
265              type, name = table.field_info(i)[:2]              columns.append(ColumnReference(col.name, col.type,
             columns.append(ColumnReference(name, type,  
266                                             self.db.new_column_name()))                                             self.db.new_column_name()))
267          TransientTableBase.create(self, columns)          TransientTableBase.create(self, columns)
268    
269          # copy the input table to the transient db          # copy the input table to the transient db
270          insert_template = "INSERT INTO %s (%s) VALUES (%s);" \  
271            # A key to insert to use for the formatting of the insert
272            # statement. The key must not be equal to any of the column
273            # names so we construct one by building a string of x's that is
274            # longer than any of the column names
275            id_key = max([len(col.name) for col in self.columns]) * "x"
276    
277            insert_template = "INSERT INTO %s (id, %s) VALUES (%%(%s)s, %s);" \
278                                 % (self.tablename,                                 % (self.tablename,
279                                    ", ".join([col.internal_name                                    ", ".join([col.internal_name
280                                               for col in self.columns]),                                               for col in self.columns]),
281                                      id_key,
282                                    ", ".join(["%%(%s)s" % col.name                                    ", ".join(["%%(%s)s" % col.name
283                                               for col in self.columns]))                                               for col in self.columns]))
284          cursor = self.db.cursor()          cursor = self.db.cursor()
285          for i in range(table.record_count()):          for i in range(table.NumRows()):
286              cursor.execute(insert_template, table.read_record(i))              row = table.ReadRowAsDict(i)
287                row[id_key] = i
288                cursor.execute(insert_template, row)
289          self.db.conn.commit()          self.db.conn.commit()
290    
291    
# Line 267  class TransientJoinedTable(TransientTabl Line 332  class TransientJoinedTable(TransientTabl
332          columns = []          columns = []
333          for col in self.left_table.columns + self.right_table.columns:          for col in self.left_table.columns + self.right_table.columns:
334              if col.name in visited:              if col.name in visited:
335                    # We can't allow multiple columns with the same original
336                    # name, so omit this one. FIXME: There should be a
337                    # better solution.
338                  continue                  continue
339              columns.append(col)              columns.append(col)
340          TransientTableBase.create(self, columns)          TransientTableBase.create(self, columns)
341    
342          # Copy the joined data to the table.          # Copy the joined data to the table.
343          internal_names = [col.internal_name for col in self.columns]          internal_names = [col.internal_name for col in self.columns]
344          stmt = "INSERT INTO %s (%s) SELECT %s FROM %s JOIN %s ON %s = %s;" \          stmt = ("INSERT INTO %s (id, %s) SELECT %s.id, %s FROM %s"
345                 % (self.tablename,                  " JOIN %s ON %s = %s;"
346                    ", ".join(internal_names),                  % (self.tablename,
347                    ", ".join(internal_names),                     ", ".join(internal_names),
348                    self.left_table.tablename,                     self.left_table.tablename,
349                    self.right_table.tablename,                     ", ".join(internal_names),
350                    self.orig_to_internal[self.left_field],                     self.left_table.tablename,
351                    self.orig_to_internal[self.right_field])                     self.right_table.tablename,
352                       self.orig_to_internal[self.left_field],
353                       self.orig_to_internal[self.right_field]))
354          self.db.execute(stmt)          self.db.execute(stmt)
355    
356    
357  class AutoTransientTable:  class AutoTransientTable(table.OldTableInterfaceMixin):
358    
359      """Table that copies data to a transient table on demand.      """Table that copies data to a transient table on demand.
360    
# Line 297  class AutoTransientTable: Line 367  class AutoTransientTable:
367          self.table = table          self.table = table
368          self.t_table = None          self.t_table = None
369    
370      def record_count(self):      def Columns(self):
371          """Return the number of records"""          return self.table.Columns()
         return self.table.record_count()  
   
     def field_count(self):  
         """Return the number of fields in a record"""  
         return self.table.field_count()  
372    
373      def field_info(self, field):      def Column(self, col):
374          """Return a tuple (type, name, width, prec) for the field no. field          return self.table.Column(col)
375    
376          type is the data type of the field, name the name, width the      def HasColumn(self, col):
377          field width in characters and prec the decimal precision.          """Return whether the table has a column with the given name or index
378          """          """
379          info = self.table.field_info(field)          return self.table.HasColumn(col)
380          if info:  
381              info = info[:2] + (0, 0)      def NumRows(self):
382          return info          return self.table.NumRows()
383    
384      def field_info_by_name(self, fieldName):      def NumColumns(self):
385          info = self.table.field_info_by_name(fieldName)          return self.table.NumColumns()
         if info:  
             info = info[:2] + (0, 0)  
         return info  
386    
387      def read_record(self, record):      def ReadRowAsDict(self, record):
388          """Return the record no. record as a dict mapping field names to values          """Return the record no. record as a dict mapping field names to values
389          """          """
390          if self.t_table is not None:          if self.t_table is not None:
391              return self.t_table.read_record(record)              return self.t_table.ReadRowAsDict(record)
392          else:          else:
393              return self.table.read_record(record)              return self.table.ReadRowAsDict(record)
   
     def write_record(self, record, values):  
         raise NotImplementedError  
394    
395      def copy_to_transient(self):      def copy_to_transient(self):
396          """Internal: Create a transient table and copy the data into it"""          """Internal: Create a transient table and copy the data into it"""
# Line 345  class AutoTransientTable: Line 404  class AutoTransientTable:
404              self.copy_to_transient()              self.copy_to_transient()
405          return self.t_table          return self.t_table
406    
407      def field_range(self, colname):      def ValueRange(self, col):
408            if self.t_table is None:
409                self.copy_to_transient()
410            return self.t_table.ValueRange(col)
411    
412        def UniqueValues(self, col):
413          if self.t_table is None:          if self.t_table is None:
414              self.copy_to_transient()              self.copy_to_transient()
415          return self.t_table.field_range(colname)          return self.t_table.UniqueValues(col)
416    
417      def GetUniqueValues(self, colname):      def SimpleQuery(self, left, comparison, right):
418          if self.t_table is None:          if self.t_table is None:
419              self.copy_to_transient()              self.copy_to_transient()
420          return self.t_table.GetUniqueValues(colname)          # Make sure to use the column object of the transient table. The
421            # left argument is always a column object so we can just ask the
422            # t_table for the right object.
423            return self.t_table.SimpleQuery(self.t_table.Column(left.name),
424                                            comparison, right)

Legend:
Removed from v.765  
changed lines
  Added in v.844

[email protected]
ViewVC Help
Powered by ViewVC 1.1.26