/[thuban]/branches/WIP-pyshapelib-bramz/test/postgissupport.py
ViewVC logotype

Diff of /branches/WIP-pyshapelib-bramz/test/postgissupport.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2057 by bh, Tue Feb 10 15:51:57 2004 UTC revision 2471 by bh, Thu Dec 16 14:19:21 2004 UTC
# Line 17  import time Line 17  import time
17  import popen2  import popen2
18  import shutil  import shutil
19  import traceback  import traceback
20    import re
21    
22  import support  import support
23    
# Line 184  class PostgreSQLServer: Line 185  class PostgreSQLServer:
185              raise RuntimeError("postmaster didn't start")              raise RuntimeError("postmaster didn't start")
186    
187      def is_running(self):      def is_running(self):
188          """Return true a postmaster process is running on self.dbdir          """Return whether a postmaster process is running on self.dbdir
189    
190          This method runs pg_ctl status on the dbdir so even if the          This method runs pg_ctl status on the dbdir and returns True if
191          object has just been created it is possible that this method          that command succeeds and False otherwise.
192          returns true if there's still a postmaster process running for  
193          self.dbdir.          Note that it is possible that this method returns true even if
194            the PostgreSQLServer instance has just been created and
195            createdb() has not been called yet.  This can happen, for
196            instance, if the server has been started manually for debugging
197            purposes after a test suite run.
198          """          """
199          return run_boolean_command(["pg_ctl", "-D", self.dbdir, "status"])          return run_boolean_command(["pg_ctl", "-D", self.dbdir, "status"])
200    
# Line 198  class PostgreSQLServer: Line 203  class PostgreSQLServer:
203          run_command(["pg_ctl", "-m", "fast", "-D", self.dbdir, "stop"],          run_command(["pg_ctl", "-m", "fast", "-D", self.dbdir, "stop"],
204                      os.path.join(self.dbdir, "pg_ctl-stop.log"))                      os.path.join(self.dbdir, "pg_ctl-stop.log"))
205    
206      def new_postgis_db(self, dbname, tables = None, reference_systems = None):      def new_postgis_db(self, dbname, tables = None, reference_systems = None,
207                           views = None):
208          """Create and return a new PostGISDatabase object using self as server          """Create and return a new PostGISDatabase object using self as server
209          """          """
210          db = PostGISDatabase(self, self.postgis_sql, dbname, tables = tables,          db = PostGISDatabase(self, self.postgis_sql, dbname, tables = tables,
211                               reference_systems = reference_systems)                               reference_systems = reference_systems,
212                                 views = views)
213          db.initdb()          db.initdb()
214          self.known_dbs[dbname] = db          self.known_dbs[dbname] = db
215          return db          return db
216    
217      def get_static_data_db(self, dbname, tables, reference_systems):      def get_static_data_db(self, dbname, tables, reference_systems, views):
218          """Return a PostGISDatabase for a database with the given static data          """Return a PostGISDatabase for a database with the given static data
219    
220          If no databasse of the name dbname exists, create a new one via          If no databasse of the name dbname exists, create a new one via
# Line 225  class PostgreSQLServer: Line 232  class PostgreSQLServer:
232          """          """
233          db = self.known_dbs.get(dbname)          db = self.known_dbs.get(dbname)
234          if db is not None:          if db is not None:
235              if db.has_data(tables, reference_systems):              if db.has_data(tables, reference_systems, views):
236                  return db                  return db
237              raise ValueError("PostGISDatabase named %r doesn't have tables %r"              raise ValueError("PostGISDatabase named %r doesn't have tables %r"
238                               % (dbname, tables))                               % (dbname, tables))
239          return self.new_postgis_db(dbname, tables, reference_systems)          return self.new_postgis_db(dbname, tables, reference_systems, views)
240    
241      def get_default_static_data_db(self):      def get_default_static_data_db(self):
242          dbname = "PostGISStaticTests"          dbname = "PostGISStaticTests"
# Line 237  class PostgreSQLServer: Line 244  class PostgreSQLServer:
244          tables = [          tables = [
245              # Direct copies of the shapefiles. The shapeids are exactly              # Direct copies of the shapefiles. The shapeids are exactly
246              # the same, except where changed with "gid_offset", of              # the same, except where changed with "gid_offset", of
247              # course              # course.  Note that the test implementation requires that
248                # all the landmard tables use an gid_offset of 1000.
249              ("landmarks", os.path.join("..", "Data", "iceland",              ("landmarks", os.path.join("..", "Data", "iceland",
250                                         "cultural_landmark-point.shp"),                                         "cultural_landmark-point.shp"),
251               [("gid_offset", 1000)]),               [("gid_offset", 1000)]),
# Line 251  class PostgreSQLServer: Line 259  class PostgreSQLServer:
259                                               "political.shp"),                                               "political.shp"),
260               [("force_wkt_type", "MULTIPOLYGON")]),               [("force_wkt_type", "MULTIPOLYGON")]),
261    
262              # Copy of landmarks but using an srid              # Copy of landmarks but using an srid != -1
263              ("landmarks_srid", os.path.join("..", "Data", "iceland",              ("landmarks_srid", os.path.join("..", "Data", "iceland",
264                                         "cultural_landmark-point.shp"),                                         "cultural_landmark-point.shp"),
265               [("gid_offset", 1000),               [("gid_offset", 1000),
266                ("srid", 1)]),                ("srid", 1)]),
267    
268                # Copy of landmarks with a gid column called "point_id" instead
269                # of "gid" and using an srid != -1.
270                ("landmarks_point_id", os.path.join("..", "Data", "iceland",
271                                                    "cultural_landmark-point.shp"),
272                 [("gid_offset", 1000),
273                  ("srid", 1),
274                  ("gid_column", "point_id")]),
275              ]              ]
276          return self.get_static_data_db(dbname, tables, srids)          views = [("v_landmarks", "SELECT * FROM landmarks_point_id")]
277            return self.get_static_data_db(dbname, tables, srids, views)
278    
279      def connection_params(self, user):      def connection_params(self, user):
280          """Return the connection parameters for the given user          """Return the connection parameters for the given user
# Line 287  class PostgreSQLServer: Line 304  class PostgreSQLServer:
304          return " ".join(params)          return " ".join(params)
305    
306      def execute_sql(self, dbname, user, sql):      def execute_sql(self, dbname, user, sql):
307          """Execute the sql statament          """Execute the sql statament and return a result for SELECT statements
308    
309          The user parameter us used as in connection_params. The dbname          The user parameter us used as in connection_params. The dbname
310          parameter must be the name of a database in the cluster.          parameter must be the name of a database in the cluster.  The
311            sql parameter is the SQL statement to execute as a string.  If
312            the string starts with 'select' (matched case insensitively) the
313            first row of the result will be returned.  Otherwise the return
314            value is None.
315          """          """
316          conn = psycopg.connect("dbname=%s " % dbname          conn = psycopg.connect("dbname=%s " % dbname
317                                 + self.connection_string(user))                                 + self.connection_string(user))
318          cursor = conn.cursor()          cursor = conn.cursor()
319          cursor.execute(sql)          cursor.execute(sql)
320            if sql.lower().startswith("select"):
321                row = cursor.fetchone()
322            else:
323                row = None
324          conn.commit()          conn.commit()
325          conn.close()          conn.close()
326            return row
327    
328        def server_version(self):
329            """Return the server version as a tuple (major, minor, patch)
330    
331            Each item in the tuple is an int.
332            """
333            result = self.execute_sql("template1", "admin", "SELECT version();")[0]
334            match = re.match(r"PostgreSQL (\d+\.\d+\.\d+)", result)
335            if match:
336                return tuple(map(int, match.group(1).split(".")))
337            else:
338                raise RutimeError("Cannot determine PostgreSQL server version"
339                                  " from %r" % result)
340    
341      def require_authentication(self, required):      def require_authentication(self, required):
342          """Switch authentication requirements on or off          """Switch authentication requirements on or off
# Line 309  class PostgreSQLServer: Line 348  class PostgreSQLServer:
348          corresponding call to switch it off again in the test case'          corresponding call to switch it off again in the test case'
349          tearDown method or in a finally: block.          tearDown method or in a finally: block.
350          """          """
351            # Starting with PostgreSQL 7.3 the pg_hba.conf file has an
352            # additional column with a username.  Query the server version
353            # and generate a file in the correct format.
354            if self.server_version() >= (7, 3):
355                user = "all"
356            else:
357                user = ""
358          if required:          if required:
359              contents = "local all password\n"              contents = "local all %s password\n" % user
360          else:          else:
361              contents = "local all trust\n"              contents = "local all %s trust\n" % user
362          f = open(os.path.join(self.dbdir, "pg_hba.conf"), "w")          f = open(os.path.join(self.dbdir, "pg_hba.conf"), "w")
363          f.write(contents)          f.write(contents)
364          f.close()          f.close()
# Line 336  class PostGISDatabase: Line 382  class PostGISDatabase:
382      """A PostGIS database in a PostgreSQLServer"""      """A PostGIS database in a PostgreSQLServer"""
383    
384      def __init__(self, server, postgis_sql, dbname, tables = None,      def __init__(self, server, postgis_sql, dbname, tables = None,
385                   reference_systems = ()):                   reference_systems = (), views = None):
386          """Initialize the PostGISDatabase          """Initialize the PostGISDatabase
387    
388          Parameters:          Parameters:
# Line 362  class PostGISDatabase: Line 408  class PostGISDatabase:
408                  (srid, params) pairs where srid is the srid defined by                  (srid, params) pairs where srid is the srid defined by
409                  the proj4 paramter string params.  The srid can be given                  the proj4 paramter string params.  The srid can be given
410                  as an extra parameter in the tables list.                  as an extra parameter in the tables list.
411    
412                views -- Optional description of views.  If given it should
413                    be a list of (viewname, select_stmt) pairs where
414                    viewname is the name of the view to be created and
415                    select_stmt is the select statement to use as the basis.
416                    The views will be created after the tables and may refer
417                    to them in the select_stmt.
418          """          """
419          self.server = server          self.server = server
420          self.postgis_sql = postgis_sql          self.postgis_sql = postgis_sql
421          self.dbname = dbname          self.dbname = dbname
422          self.tables = tables          self.tables = tables
423            self.views = views
424          if reference_systems:          if reference_systems:
425              self.reference_systems = reference_systems              self.reference_systems = reference_systems
426          else:          else:
# Line 422  class PostGISDatabase: Line 476  class PostGISDatabase:
476                  tablename, shapefile, kw = unpack(info)                  tablename, shapefile, kw = unpack(info)
477                  upload_shapefile(shapefile, self, tablename, **kw)                  upload_shapefile(shapefile, self, tablename, **kw)
478    
479      def has_data(self, tables, reference_systems):          if self.views is not None:
480                for viewname, select_stmt in self.views:
481                    self.server.execute_sql(self.dbname, "admin",
482                                            "CREATE VIEW %s AS %s" % (viewname,
483                                                                      select_stmt))
484                    self.server.execute_sql(self.dbname, "admin",
485                                            "GRANT SELECT ON %s TO PUBLIC;"
486                                            % viewname)
487    
488        def has_data(self, tables, reference_systems, views):
489          return (self.tables == tables          return (self.tables == tables
490                  and self.reference_systems == reference_systems)                  and self.reference_systems == reference_systems
491                    and self.views == views)
492    
493    
494  def find_postgis_sql():  def find_postgis_sql():
495      """Return the name of the postgis_sql file      """Return the name of the postgis_sql file
496    
497      A postgis installation usually has the postgis_sql file in      A postgis installation usually has the postgis_sql file in
498      PostgreSQL's datadir (i.e. the directory where PostgreSQL keeps      PostgreSQL's $datadir (i.e. the directory where PostgreSQL keeps
499      static files, not the directory containing the databases).      static files, not the directory containing the databases).
500      Unfortunately there's no way to determine the name of this directory      Unfortunately there's no way to determine the name of this directory
501      with pg_config so we assume here that it's      with pg_config so we assume here that it's
502      $bindir/../share/postgresql/.      $bindir/../share/postgresql/.
503    
504        Furthermore, different versions of postgis place the file in
505        slightly different locations.  For instance:
506    
507          postgis 0.7.5        $datadir/contrib/postgis.sql
508          postgis 0.8.1        $datadir/postgis.sql
509    
510        To support both versions, we look in both places and return the
511        first one found (looking under contrib first).  If the file is not
512        found the return value is None.
513      """      """
514      bindir = run_config_script("pg_config --bindir").strip()      bindir = run_config_script("pg_config --bindir").strip()
515      return os.path.join(bindir, "..", "share", "postgresql",      datadir = os.path.join(bindir, "..", "share", "postgresql")
516                          "contrib", "postgis.sql")      for filename in [os.path.join(datadir, "contrib", "postgis.sql"),
517                         os.path.join(datadir, "postgis.sql")]:
518            if os.path.exists(filename):
519                return filename
520    
521    
522  _postgres_server = None  _postgres_server = None
523  def get_test_server():  def get_test_server():
# Line 529  def skip_if_addgeometrycolumn_does_not_u Line 607  def skip_if_addgeometrycolumn_does_not_u
607    
608      The test performed by this function is a bit simplistic because it      The test performed by this function is a bit simplistic because it
609      only tests whether the string 'quote_ident' occurs anywhere in the      only tests whether the string 'quote_ident' occurs anywhere in the
610      postgis.sql file. This will hopefully works because when this was      postgis.sql file. This will hopefully work because when this was
611      fixed in postgis CVS AddGeometryColumn was the first function to use      fixed in postgis CVS AddGeometryColumn was the first function to use
612      quote_ident.      quote_ident.
613      """      """
# Line 573  wkt_converter = { Line 651  wkt_converter = {
651      }      }
652    
653  def upload_shapefile(filename, db, tablename, force_wkt_type = None,  def upload_shapefile(filename, db, tablename, force_wkt_type = None,
654                       gid_offset = 0, srid = -1):                       gid_offset = 0, gid_column = "gid", srid = -1):
655      """Upload a shapefile into a new database table      """Upload a shapefile into a new database table
656    
657      Parameters:      Parameters:
# Line 592  def upload_shapefile(filename, db, table Line 670  def upload_shapefile(filename, db, table
670      gid_offset -- A number to add to the shapeid to get the value for      gid_offset -- A number to add to the shapeid to get the value for
671                  the gid column (default 0)                  the gid column (default 0)
672    
673        gid_column -- The name of the column with the shape ids.  Default
674                      'gid'.  If None, no gid column will be created.  The
675                      name is directly used in SQL statements, so if it
676                      contains unusualy characters the caller should provide
677                      a suitable quoted string.
678    
679      srid -- The srid of the spatial references system used by the table      srid -- The srid of the spatial references system used by the table
680              and the data              and the data
681      """      """
# Line 618  def upload_shapefile(filename, db, table Line 702  def upload_shapefile(filename, db, table
702                 dbflib.FTInteger: "INTEGER",                 dbflib.FTInteger: "INTEGER",
703                 dbflib.FTDouble: "DOUBLE PRECISION"}                 dbflib.FTDouble: "DOUBLE PRECISION"}
704    
705      insert_formats = ["%(gid)s"]      insert_formats = []
706      fields = ["gid INT"]      if gid_column:
707            insert_formats.append("%(gid)s")
708    
709        fields = []
710        fields_decl = []
711        if gid_column:
712            fields.append(gid_column)
713            fields_decl.append("%s INT" % gid_column)
714      for i in range(dbf.field_count()):      for i in range(dbf.field_count()):
715          ftype, name, width, prec = dbf.field_info(i)          ftype, name, width, prec = dbf.field_info(i)
716          fields.append("%s %s" % (name, typemap[ftype]))          fields.append(name)
717            fields_decl.append("%s %s" % (name, typemap[ftype]))
718          insert_formats.append("%%(%s)s" % name)          insert_formats.append("%%(%s)s" % name)
719      stmt = "CREATE TABLE %s (\n    %s\n);" % (tablename,      stmt = "CREATE TABLE %s (\n    %s\n);" % (tablename,
720                                                ",\n    ".join(fields))                                                ",\n    ".join(fields_decl))
721      cursor.execute(stmt)      cursor.execute(stmt)
722      #print stmt      #print stmt
723    
# Line 638  def upload_shapefile(filename, db, table Line 730  def upload_shapefile(filename, db, table
730      cursor.execute("select AddGeometryColumn('%(dbname)s',"      cursor.execute("select AddGeometryColumn('%(dbname)s',"
731                     "'%(tablename)s', 'the_geom', %(srid)d, '%(wkttype)s', 2);"                     "'%(tablename)s', 'the_geom', %(srid)d, '%(wkttype)s', 2);"
732                     % locals())                     % locals())
733        fields.append("the_geom")
734      insert_formats.append("GeometryFromText(%(the_geom)s, %(srid)d)")      insert_formats.append("GeometryFromText(%(the_geom)s, %(srid)d)")
735    
736      insert = ("INSERT INTO %s VALUES (%s)"      insert = ("INSERT INTO %s (%s) VALUES (%s)"
737                % (tablename, ", ".join(insert_formats)))                % (tablename, ", ".join(fields), ", ".join(insert_formats)))
738    
739      for i in range(numshapes):      for i in range(numshapes):
740          data = dbf.read_record(i)          data = dbf.read_record(i)
741          data["tablename"] = tablename          data["tablename"] = tablename
742          data["gid"] = i + gid_offset          if gid_column:
743                data["gid"] = i + gid_offset
744          data["srid"] = srid          data["srid"] = srid
745          data["the_geom"] = convert(shp.read_object(i).vertices())          data["the_geom"] = convert(shp.read_object(i).vertices())
746          #print insert % data          #print insert % data

Legend:
Removed from v.2057  
changed lines
  Added in v.2471

[email protected]
ViewVC Help
Powered by ViewVC 1.1.26