/[thuban]/branches/WIP-pyshapelib-bramz/test/postgissupport.py
ViewVC logotype

Diff of /branches/WIP-pyshapelib-bramz/test/postgissupport.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 1605 by bh, Tue Aug 19 11:00:40 2003 UTC revision 2096 by bh, Thu Mar 11 13:50:53 2004 UTC
# Line 1  Line 1 
1  # Copyright (C) 2003 by Intevation GmbH  # Copyright (C) 2003, 2004 by Intevation GmbH
2  # Authors:  # Authors:
3  # Bernhard Herzog <[email protected]>  # Bernhard Herzog <[email protected]>
4  #  #
# Line 115  class PostgreSQLServer: Line 115  class PostgreSQLServer:
115          self.socket_dir = socket_dir          self.socket_dir = socket_dir
116    
117          # For the client side the socket directory can be used as the          # For the client side the socket directory can be used as the
118          # host the name starts with a slash.          # host if the name starts with a slash.
119          self.host = os.path.abspath(socket_dir)          self.host = os.path.abspath(socket_dir)
120    
121            # name and password for the admin and an unprivileged user
122            self.admin_name = "postgres"
123            self.admin_password = "postgres"
124            self.user_name = "observer"
125            self.user_password = "telescope"
126    
127          # Map db names to db objects          # Map db names to db objects
128          self.known_dbs = {}          self.known_dbs = {}
129    
# Line 135  class PostgreSQLServer: Line 141  class PostgreSQLServer:
141              shutil.rmtree(self.dbdir)              shutil.rmtree(self.dbdir)
142          os.mkdir(self.dbdir)          os.mkdir(self.dbdir)
143    
144          run_command(["initdb", self.dbdir],          run_command(["initdb", "-D", self.dbdir, "-U", self.admin_name],
145                      os.path.join(self.dbdir, "initdb.log"))                      os.path.join(self.dbdir, "initdb.log"))
146    
147          extra_opts = "-p %d" % self.port          extra_opts = "-p %d" % self.port
# Line 150  class PostgreSQLServer: Line 156  class PostgreSQLServer:
156          # server ourselves          # server ourselves
157          self.wait_for_postmaster()          self.wait_for_postmaster()
158    
159            self.alter_user(self.admin_name, self.admin_password)
160            self.create_user(self.user_name, self.user_password)
161    
162      def wait_for_postmaster(self):      def wait_for_postmaster(self):
163          """Return when the database server is running          """Return when the database server is running
164    
# Line 161  class PostgreSQLServer: Line 170  class PostgreSQLServer:
170          while count < max_count:          while count < max_count:
171              try:              try:
172                  run_command(["psql", "-l", "-p", str(self.port),                  run_command(["psql", "-l", "-p", str(self.port),
173                               "-h", self.host],                               "-h", self.host, "-U", self.admin_name],
174                              os.path.join(self.dbdir, "psql-%d.log" % count))                              os.path.join(self.dbdir, "psql-%d.log" % count))
175              except:              except RuntimeError:
176                  pass                  pass
177                except:
178                    traceback.print_exc()
179              else:              else:
180                  break                  break
181              time.sleep(0.5)              time.sleep(0.5)
# Line 187  class PostgreSQLServer: Line 198  class PostgreSQLServer:
198          run_command(["pg_ctl", "-m", "fast", "-D", self.dbdir, "stop"],          run_command(["pg_ctl", "-m", "fast", "-D", self.dbdir, "stop"],
199                      os.path.join(self.dbdir, "pg_ctl-stop.log"))                      os.path.join(self.dbdir, "pg_ctl-stop.log"))
200    
201      def new_postgis_db(self, dbname, tables = None):      def new_postgis_db(self, dbname, tables = None, reference_systems = None):
202          """Create and return a new PostGISDatabase object using self as server          """Create and return a new PostGISDatabase object using self as server
203          """          """
204          db = PostGISDatabase(self, self.postgis_sql, dbname, tables = tables)          db = PostGISDatabase(self, self.postgis_sql, dbname, tables = tables,
205                                 reference_systems = reference_systems)
206          db.initdb()          db.initdb()
207          self.known_dbs[dbname] = db          self.known_dbs[dbname] = db
208          return db          return db
209    
210      def get_static_data_db(self, dbname, tables):      def get_static_data_db(self, dbname, tables, reference_systems):
211          """Return a PostGISDatabase for a database with the given static data          """Return a PostGISDatabase for a database with the given static data
212    
213          If no databasse of the name dbname exists, create a new one via          If no databasse of the name dbname exists, create a new one via
# Line 205  class PostgreSQLServer: Line 217  class PostgreSQLServer:
217          indicated data, return that. If the already existing db uses          indicated data, return that. If the already existing db uses
218          different data raise a value error.          different data raise a value error.
219    
220          The tables argument should be a sequence of table specifications          If the database doesn't exist, create a new one via
221          where each specifications is a (tablename, shapefilename) pair.          self.new_postgis_db.
222    
223            The parameters tables and reference_systems have the same
224            meaning as for new_postgis_db.
225          """          """
226          db = self.known_dbs.get(dbname)          db = self.known_dbs.get(dbname)
227          if db is not None:          if db is not None:
228              if db.has_data(tables):              if db.has_data(tables, reference_systems):
229                  return db                  return db
230              raise ValueError("PostGISDatabase named %r doesn't have tables %r"              raise ValueError("PostGISDatabase named %r doesn't have tables %r"
231                               % (dbname, tables))                               % (dbname, tables))
232          return self.new_postgis_db(dbname, tables)          return self.new_postgis_db(dbname, tables, reference_systems)
233    
234      def get_default_static_data_db(self):      def get_default_static_data_db(self):
235          dbname = "PostGISStaticTests"          dbname = "PostGISStaticTests"
236          tables = [("landmarks", os.path.join("..", "Data", "iceland",          srids = [(1, "proj=longlat datum=WGS84")]
237                                               "cultural_landmark-point.shp")),          tables = [
238                    ("political", os.path.join("..", "Data", "iceland",              # Direct copies of the shapefiles. The shapeids are exactly
239                # the same, except where changed with "gid_offset", of
240                # course.  Note that the test implementation requires that
241                # all the landmard tables use an gid_offset of 1000.
242                ("landmarks", os.path.join("..", "Data", "iceland",
243                                           "cultural_landmark-point.shp"),
244                 [("gid_offset", 1000)]),
245                ("political", os.path.join("..", "Data", "iceland",
246                                               "political.shp")),                                               "political.shp")),
247                    ("roads", os.path.join("..", "Data", "iceland",              ("roads", os.path.join("..", "Data", "iceland",
248                                           "roads-line.shp"))]                                           "roads-line.shp")),
249          return self.get_static_data_db(dbname, tables)  
250                # The polygon data as a MULTIPOLYGON geometry type
251                ("political_multi", os.path.join("..", "Data", "iceland",
252                                                 "political.shp"),
253                 [("force_wkt_type", "MULTIPOLYGON")]),
254    
255                # Copy of landmarks but using an srid != -1
256                ("landmarks_srid", os.path.join("..", "Data", "iceland",
257                                           "cultural_landmark-point.shp"),
258                 [("gid_offset", 1000),
259                  ("srid", 1)]),
260    
261                # Copy of landmarks with a gid column called "point_id" instead
262                # of "gid" and using an srid != -1.
263                ("landmarks_point_id", os.path.join("..", "Data", "iceland",
264                                                    "cultural_landmark-point.shp"),
265                 [("gid_offset", 1000),
266                  ("srid", 1),
267                  ("gid_column", "point_id")]),
268                ]
269            return self.get_static_data_db(dbname, tables, srids)
270    
271        def connection_params(self, user):
272            """Return the connection parameters for the given user
273    
274            The return value is a dictionary suitable as keyword argument
275            list to PostGISConnection. The user parameter may be either
276            'admin' to connect as admin or 'user' to connect as an
277            unprivileged user.
278            """
279            return {"host": self.host, "port": self.port,
280                    "user": getattr(self, user + "_name"),
281                    "password": getattr(self, user + "_password")}
282    
283        def connection_string(self, user):
284            """Return (part of) the connection string to pass to psycopg.connect
285    
286            The string contains host, port, user and password. The user
287            parameter must be either 'admin' or 'user', as for
288            connection_params.
289            """
290            params = []
291            for key, value in self.connection_params(user).items():
292                # FIXME: this doesn't do quiting correctly but that
293                # shouldn't be much of a problem (people shouldn't be using
294                # single quotes in filenames anyway :) )
295                params.append("%s='%s'" % (key, value))
296            return " ".join(params)
297    
298        def execute_sql(self, dbname, user, sql):
299            """Execute the sql statament
300    
301            The user parameter us used as in connection_params. The dbname
302            parameter must be the name of a database in the cluster.
303            """
304            conn = psycopg.connect("dbname=%s " % dbname
305                                   + self.connection_string(user))
306            cursor = conn.cursor()
307            cursor.execute(sql)
308            conn.commit()
309            conn.close()
310    
311        def require_authentication(self, required):
312            """Switch authentication requirements on or off
313    
314            When started for the first time no passwords are required. Some
315            tests want to explicitly test whether Thuban's password
316            infrastructure works and switch password authentication on
317            explicitly. When switching it on, there should be a
318            corresponding call to switch it off again in the test case'
319            tearDown method or in a finally: block.
320            """
321            if required:
322                contents = "local all password\n"
323            else:
324                contents = "local all trust\n"
325            f = open(os.path.join(self.dbdir, "pg_hba.conf"), "w")
326            f.write(contents)
327            f.close()
328            run_command(["pg_ctl", "-D", self.dbdir, "reload"],
329                        os.path.join(self.dbdir, "pg_ctl-reload.log"))
330    
331    
332        def create_user(self, username, password):
333            """Create user username with password in the database"""
334            self.execute_sql("template1", "admin",
335                             "CREATE USER %s PASSWORD '%s';" % (username,password))
336    
337        def alter_user(self, username, password):
338            """Change the user username's password in the database"""
339            self.execute_sql("template1", "admin",
340                             "ALTER USER %s PASSWORD '%s';" % (username,password))
341    
342    
343  class PostGISDatabase:  class PostGISDatabase:
344    
345      """A PostGIS database in a PostgreSQLServer"""      """A PostGIS database in a PostgreSQLServer"""
346    
347      def __init__(self, server, postgis_sql, dbname, tables = None):      def __init__(self, server, postgis_sql, dbname, tables = None,
348                     reference_systems = ()):
349            """Initialize the PostGISDatabase
350    
351            Parameters:
352    
353                server -- The PostgreSQLServer instance containing the
354                    database
355    
356                postgis_sql -- Filename of the postgis.sql file with the
357                    postgis initialization code
358    
359                dbname -- The name of the database
360    
361                tables -- Optional description of tables to create in the
362                    new database. If given it should be a list of
363                    (tablename, shapefilename) pairs meaning that a table
364                    tablename will be created with the contents of the given
365                    shapefile or (tablename, shapefilename, extraargs)
366                    triples. The extraargs should be a list of key, value
367                    pairs to use as keyword arguments to upload_shapefile.
368    
369                reference_systems -- Optional description of spatial
370                    reference systems.  If given, it should be a sequence of
371                    (srid, params) pairs where srid is the srid defined by
372                    the proj4 paramter string params.  The srid can be given
373                    as an extra parameter in the tables list.
374            """
375          self.server = server          self.server = server
376          self.postgis_sql = postgis_sql          self.postgis_sql = postgis_sql
377          self.dbname = dbname          self.dbname = dbname
378          self.tables = tables          self.tables = tables
379            if reference_systems:
380                self.reference_systems = reference_systems
381            else:
382                # Make sure that it's a sequence we can iterate over even if
383                # the parameter's None
384                self.reference_systems = ()
385    
386      def initdb(self):      def initdb(self):
387          """Remove the old db directory and create and initialize a new database          """Remove the old db directory and create and initialize a new database
388          """          """
389          run_command(["createdb", "-p", str(self.server.port),          run_command(["createdb", "-p", str(self.server.port),
390                       "-h", self.server.host, self.dbname],                       "-h", self.server.host, "-U", self.server.admin_name,
391                         self.dbname],
392                      os.path.join(self.server.dbdir, "createdb.log"))                      os.path.join(self.server.dbdir, "createdb.log"))
393          run_command(["createlang", "-p", str(self.server.port),          run_command(["createlang", "-p", str(self.server.port),
394                       "-h", self.server.host, "plpgsql", self.dbname],                       "-h", self.server.host,  "-U", self.server.admin_name,
395                         "plpgsql", self.dbname],
396                      os.path.join(self.server.dbdir, "createlang.log"))                      os.path.join(self.server.dbdir, "createlang.log"))
397          # for some reason psql doesn't exit with an error code if the          # for some reason psql doesn't exit with an error code if the
398          # file given as -f doesn't exist, so we check manually by trying          # file given as -f doesn't exist, so we check manually by trying
# Line 255  class PostGISDatabase: Line 401  class PostGISDatabase:
401          f.close()          f.close()
402          del f          del f
403          run_command(["psql", "-f", self.postgis_sql, "-d", self.dbname,          run_command(["psql", "-f", self.postgis_sql, "-d", self.dbname,
404                       "-p", str(self.server.port), "-h", self.server.host],                       "-p", str(self.server.port), "-h", self.server.host,
405                         "-U", self.server.admin_name],
406                       os.path.join(self.server.dbdir, "psql.log"))                       os.path.join(self.server.dbdir, "psql.log"))
407    
408            self.server.execute_sql(self.dbname, "admin",
409                                    "GRANT SELECT ON geometry_columns TO PUBLIC;")
410            self.server.execute_sql(self.dbname, "admin",
411                                    "GRANT SELECT ON spatial_ref_sys TO PUBLIC;")
412    
413            for srid, params in self.reference_systems:
414                self.server.execute_sql(self.dbname, "admin",
415                                        "INSERT INTO spatial_ref_sys VALUES"
416                                        " (%d, '', %d, '', '%s');"
417                                        % (srid, srid, params))
418          if self.tables is not None:          if self.tables is not None:
419              for tablename, shapefile in self.tables:              def unpack(item):
420                  upload_shapefile(shapefile, self, tablename)                  extra = {"force_wkt_type": None, "gid_offset": 0,
421                             "srid": -1}
422      def has_data(self, tables):                  if len(info) == 2:
423          return self.tables == tables                      tablename, shapefile = info
424                    else:
425                        tablename, shapefile, kw = info
426                        for key, val in kw:
427                            extra[key] = val
428                    return tablename, shapefile, extra
429    
430                for info in self.tables:
431                    tablename, shapefile, kw = unpack(info)
432                    upload_shapefile(shapefile, self, tablename, **kw)
433    
434        def has_data(self, tables, reference_systems):
435            return (self.tables == tables
436                    and self.reference_systems == reference_systems)
437    
438    
439  def find_postgis_sql():  def find_postgis_sql():
# Line 322  def reason_for_not_running_tests(): Line 492  def reason_for_not_running_tests():
492         The name of the postgis_sql file is determined by find_postgis_sql()         The name of the postgis_sql file is determined by find_postgis_sql()
493       - psycopg can be imported successfully.       - psycopg can be imported successfully.
494      """      """
495        # run_command currently uses Popen4 which is not available under
496        # Windows, for example.
497        if not hasattr(popen2, "Popen4"):
498            return "Can't run PostGIS test because popen2.Popen4 does not exist"
499    
500      try:      try:
501          run_command(["pg_ctl", "--help"], None)          run_command(["pg_ctl", "--help"], None)
502      except RuntimeError:      except RuntimeError:
# Line 354  def skip_if_no_postgis(): Line 529  def skip_if_no_postgis():
529      if _cannot_run_postgis_tests:      if _cannot_run_postgis_tests:
530          raise support.SkipTest(_cannot_run_postgis_tests)          raise support.SkipTest(_cannot_run_postgis_tests)
531    
532  def point_to_wkt(coords):  def skip_if_addgeometrycolumn_does_not_use_quote_ident():
533        """Skip a test if the AddGeometryColumn function doesn't use quote_ident
534    
535        If the AddGeometryColumn function doesn't use quote_ident it doesn't
536        support unusual table or column names properly, that is, it will
537        fail with errors for names that contain spaces or double quotes.
538    
539        The test performed by this function is a bit simplistic because it
540        only tests whether the string 'quote_ident' occurs anywhere in the
541        postgis.sql file. This will hopefully work because when this was
542        fixed in postgis CVS AddGeometryColumn was the first function to use
543        quote_ident.
544        """
545        f = file(find_postgis_sql())
546        content = f.read()
547        f.close()
548        if content.find("quote_ident") < 0:
549            raise support.SkipTest("AddGeometryColumn doesn't use quote_ident")
550    
551    def coords_to_point(coords):
552      """Return string with a WKT representation of the point in coords"""      """Return string with a WKT representation of the point in coords"""
553      x, y = coords[0]      x, y = coords[0]
554      return "POINT(%r %r)" % (x, y)      return "POINT(%r %r)" % (x, y)
555    
556  def polygon_to_wkt(coords):  def coords_to_polygon(coords):
557      """Return string with a WKT representation of the polygon in coords"""      """Return string with a WKT representation of the polygon in coords"""
558      poly = []      poly = []
559      for ring in coords:      for ring in coords:
560          poly.append(", ".join(["%r %r" % p for p in ring]))          poly.append(", ".join(["%r %r" % p for p in ring]))
561      return "POLYGON((%s))" % "), (".join(poly)      return "POLYGON((%s))" % "), (".join(poly)
562    
563  def arc_to_wkt(coords):  def coords_to_multilinestring(coords):
564      """Return string with a WKT representation of the arc in coords"""      """Return string with a WKT representation of the arc in coords"""
565      poly = []      poly = []
566      for ring in coords:      for ring in coords:
567          poly.append(", ".join(["%r %r" % p for p in ring]))          poly.append(", ".join(["%r %r" % p for p in ring]))
568      return "MULTILINESTRING((%s))" % "), (".join(poly)      return "MULTILINESTRING((%s))" % "), (".join(poly)
569    
570  def upload_shapefile(filename, db, tablename):  def coords_to_multipolygon(coords):
571        """Return string with a WKT representation of the polygon in coords"""
572        poly = []
573        for ring in coords:
574            poly.append(", ".join(["%r %r" % p for p in ring]))
575        return "MULTIPOLYGON(((%s)))" % ")), ((".join(poly)
576    
577    wkt_converter = {
578        "POINT": coords_to_point,
579        "MULTILINESTRING": coords_to_multilinestring,
580        "POLYGON": coords_to_polygon,
581        "MULTIPOLYGON": coords_to_multipolygon,
582        }
583    
584    def upload_shapefile(filename, db, tablename, force_wkt_type = None,
585                         gid_offset = 0, gid_column = "gid", srid = -1):
586        """Upload a shapefile into a new database table
587    
588        Parameters:
589    
590        filename -- The name of the shapefile
591    
592        db -- The PostGISDatabase instance representing the database
593    
594        tablename -- The name of the table to create and into which the data
595                    is to be inserted
596    
597        force_wkt_type -- If given the real WKT geometry type to use instead
598                    of the default that would be chosen based on the type of
599                    the shapefile
600    
601        gid_offset -- A number to add to the shapeid to get the value for
602                    the gid column (default 0)
603    
604        gid_column -- The name of the column with the shape ids.  Default
605                      'gid'.  If None, no gid column will be created.  The
606                      name is directly used in SQL statements, so if it
607                      contains unusualy characters the caller should provide
608                      a suitable quoted string.
609    
610        srid -- The srid of the spatial references system used by the table
611                and the data
612        """
613      import dbflib, shapelib      import dbflib, shapelib
614    
615        # We build this map here because we need shapelib which can only be
616        # imported after support.initthuban has been called which we can't
617        # easily do in this module because it's imported by support.
618        shp_to_wkt = {
619            shapelib.SHPT_POINT: "POINT",
620            shapelib.SHPT_ARC: "MULTILINESTRING",
621            shapelib.SHPT_POLYGON: "POLYGON",
622            }
623    
624      server = db.server      server = db.server
625      dbname = db.dbname      dbname = db.dbname
626      conn = psycopg.connect("host=%s port=%s dbname=%s"      conn = psycopg.connect("dbname=%s " % dbname
627                             % (server.host, server.port, dbname))                             + db.server.connection_string("admin"))
628      cursor = conn.cursor()      cursor = conn.cursor()
629    
630      shp = shapelib.ShapeFile(filename)      shp = shapelib.ShapeFile(filename)
# Line 388  def upload_shapefile(filename, db, table Line 633  def upload_shapefile(filename, db, table
633                 dbflib.FTInteger: "INTEGER",                 dbflib.FTInteger: "INTEGER",
634                 dbflib.FTDouble: "DOUBLE PRECISION"}                 dbflib.FTDouble: "DOUBLE PRECISION"}
635    
636      insert_formats = ["%(gid)s"]      insert_formats = []
637      fields = ["gid INT"]      if gid_column:
638            insert_formats.append("%(gid)s")
639    
640        fields = []
641        fields_decl = []
642        if gid_column:
643            fields.append(gid_column)
644            fields_decl.append("%s INT" % gid_column)
645      for i in range(dbf.field_count()):      for i in range(dbf.field_count()):
646          ftype, name, width, prec = dbf.field_info(i)          ftype, name, width, prec = dbf.field_info(i)
647          fields.append("%s %s" % (name, typemap[ftype]))          fields.append(name)
648            fields_decl.append("%s %s" % (name, typemap[ftype]))
649          insert_formats.append("%%(%s)s" % name)          insert_formats.append("%%(%s)s" % name)
650      stmt = "CREATE TABLE %s (\n    %s\n);" % (tablename,      stmt = "CREATE TABLE %s (\n    %s\n);" % (tablename,
651                                                ",\n    ".join(fields))                                                ",\n    ".join(fields_decl))
652      cursor.execute(stmt)      cursor.execute(stmt)
653      #print stmt      #print stmt
654    
655      numshapes, shapetype, mins, maxs = shp.info()      numshapes, shapetype, mins, maxs = shp.info()
656      if shapetype == shapelib.SHPT_POINT:      wkttype =  shp_to_wkt[shapetype]
657          convert = point_to_wkt      if force_wkt_type:
658          wkttype = "POINT"          wkttype = force_wkt_type
659      elif shapetype == shapelib.SHPT_POLYGON:      convert = wkt_converter[wkttype]
         convert = polygon_to_wkt  
         wkttype = "POLYGON"  
     elif shapetype == shapelib.SHPT_ARC:  
         convert = arc_to_wkt  
         wkttype = "MULTILINESTRING"  
     else:  
         raise ValueError("Unsupported Shapetype %r" % shapetype)  
660    
661      cursor.execute("select AddGeometryColumn('%(dbname)s',"      cursor.execute("select AddGeometryColumn('%(dbname)s',"
662                     "'%(tablename)s', 'the_geom', '-1', '%(wkttype)s', 2);"                     "'%(tablename)s', 'the_geom', %(srid)d, '%(wkttype)s', 2);"
663                     % locals())                     % locals())
664        fields.append("the_geom")
665        insert_formats.append("GeometryFromText(%(the_geom)s, %(srid)d)")
666    
667      insert_formats.append("GeometryFromText(%(the_geom)s, -1)")      insert = ("INSERT INTO %s (%s) VALUES (%s)"
668                  % (tablename, ", ".join(fields), ", ".join(insert_formats)))
     insert = ("INSERT INTO %s VALUES (%s)"  
               % (tablename, ", ".join(insert_formats)))  
669    
670      for i in range(numshapes):      for i in range(numshapes):
671          data = dbf.read_record(i)          data = dbf.read_record(i)
672          data["tablename"] = tablename          data["tablename"] = tablename
673          data["gid"] = i          if gid_column:
674                data["gid"] = i + gid_offset
675            data["srid"] = srid
676          data["the_geom"] = convert(shp.read_object(i).vertices())          data["the_geom"] = convert(shp.read_object(i).vertices())
677          #print insert % data          #print insert % data
678          cursor.execute(insert, data)          cursor.execute(insert, data)
679    
680        cursor.execute("GRANT SELECT ON %s TO PUBLIC;" % tablename)
681    
682      conn.commit()      conn.commit()

Legend:
Removed from v.1605  
changed lines
  Added in v.2096

[email protected]
ViewVC Help
Powered by ViewVC 1.1.26