]> granicus.if.org Git - postgresql/blobdiff - src/backend/commands/extension.c
Add CASCADE support for CREATE EXTENSION.
[postgresql] / src / backend / commands / extension.c
index 6b92bdc5e0e821395604da976af2a0448e969b98..67b16a7a68d3d0a75e0eeda2da1db3769b3957bd 100644 (file)
 #include "catalog/pg_type.h"
 #include "commands/alter.h"
 #include "commands/comment.h"
+#include "commands/defrem.h"
 #include "commands/extension.h"
 #include "commands/schemacmds.h"
 #include "funcapi.h"
 #include "mb/pg_wchar.h"
 #include "miscadmin.h"
+#include "nodes/makefuncs.h"
 #include "storage/fd.h"
 #include "tcop/utility.h"
 #include "utils/builtins.h"
@@ -1165,18 +1167,25 @@ find_update_path(List *evi_list,
 }
 
 /*
- * CREATE EXTENSION
+ * CREATE EXTENSION worker
+ *
+ * When CASCADE is specified CreateExtensionInternal() recurses if required
+ * extensions need to be installed. To sanely handle cyclic dependencies
+ * cascade_parent contains the dependency chain leading to the current
+ * invocation; thus allowing to error out if there's a cyclic dependency.
  */
-ObjectAddress
-CreateExtension(CreateExtensionStmt *stmt)
+static ObjectAddress
+CreateExtensionInternal(CreateExtensionStmt *stmt, List *parents)
 {
        DefElem    *d_schema = NULL;
        DefElem    *d_new_version = NULL;
        DefElem    *d_old_version = NULL;
-       char       *schemaName;
-       Oid                     schemaOid;
+       DefElem    *d_cascade = NULL;
+       char       *schemaName = NULL;
+       Oid                     schemaOid = InvalidOid;
        char       *versionName;
        char       *oldVersionName;
+       bool            cascade = false;
        Oid                     extowner = GetUserId();
        ExtensionControlFile *pcontrol;
        ExtensionControlFile *control;
@@ -1187,41 +1196,6 @@ CreateExtension(CreateExtensionStmt *stmt)
        ListCell   *lc;
        ObjectAddress address;
 
-       /* Check extension name validity before any filesystem access */
-       check_valid_extension_name(stmt->extname);
-
-       /*
-        * Check for duplicate extension name.  The unique index on
-        * pg_extension.extname would catch this anyway, and serves as a backstop
-        * in case of race conditions; but this is a friendlier error message, and
-        * besides we need a check to support IF NOT EXISTS.
-        */
-       if (get_extension_oid(stmt->extname, true) != InvalidOid)
-       {
-               if (stmt->if_not_exists)
-               {
-                       ereport(NOTICE,
-                                       (errcode(ERRCODE_DUPLICATE_OBJECT),
-                                        errmsg("extension \"%s\" already exists, skipping",
-                                                       stmt->extname)));
-                       return InvalidObjectAddress;
-               }
-               else
-                       ereport(ERROR,
-                                       (errcode(ERRCODE_DUPLICATE_OBJECT),
-                                        errmsg("extension \"%s\" already exists",
-                                                       stmt->extname)));
-       }
-
-       /*
-        * We use global variables to track the extension being created, so we can
-        * create only one extension at the same time.
-        */
-       if (creating_extension)
-               ereport(ERROR,
-                               (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
-                                errmsg("nested CREATE EXTENSION is not supported")));
-
        /*
         * Read the primary control file.  Note we assume that it does not contain
         * any non-ASCII data, so there is no need to worry about encoding at this
@@ -1260,6 +1234,15 @@ CreateExtension(CreateExtensionStmt *stmt)
                                                 errmsg("conflicting or redundant options")));
                        d_old_version = defel;
                }
+               else if (strcmp(defel->defname, "cascade") == 0)
+               {
+                       if (d_cascade)
+                               ereport(ERROR,
+                                               (errcode(ERRCODE_SYNTAX_ERROR),
+                                                errmsg("conflicting or redundant options")));
+                       d_cascade = defel;
+                       cascade = defGetBoolean(d_cascade);
+               }
                else
                        elog(ERROR, "unrecognized option: %s", defel->defname);
        }
@@ -1337,33 +1320,37 @@ CreateExtension(CreateExtensionStmt *stmt)
        {
                /*
                 * User given schema, CREATE EXTENSION ... WITH SCHEMA ...
-                *
-                * It's an error to give a schema different from control->schema if
-                * control->schema is specified.
                 */
                schemaName = strVal(d_schema->arg);
 
-               if (control->schema != NULL &&
-                       strcmp(control->schema, schemaName) != 0)
-                       ereport(ERROR,
-                                       (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
-                               errmsg("extension \"%s\" must be installed in schema \"%s\"",
-                                          control->name,
-                                          control->schema)));
-
-               /* If the user is giving us the schema name, it must exist already */
+               /* If the user is giving us the schema name, it must exist already. */
                schemaOid = get_namespace_oid(schemaName, false);
        }
-       else if (control->schema != NULL)
+
+       if (control->schema != NULL)
        {
                /*
                 * The extension is not relocatable and the author gave us a schema
-                * for it.  We create the schema here if it does not already exist.
+                * for it.
+                *
+                * Unless CASCADE parameter was given, it's an error to give a schema
+                * different from control->schema if control->schema is specified.
                 */
+               if (schemaName && strcmp(control->schema, schemaName) != 0 &&
+                       !cascade)
+                       ereport(ERROR,
+                                       (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+                               errmsg("extension \"%s\" must be installed in schema \"%s\"",
+                                          control->name,
+                                          control->schema)));
+
+               /* Always use the schema from control file for current extension. */
                schemaName = control->schema;
+
+               /* Find or create the schema in case it does not exist. */
                schemaOid = get_namespace_oid(schemaName, true);
 
-               if (schemaOid == InvalidOid)
+               if (!OidIsValid(schemaOid))
                {
                        CreateSchemaStmt *csstmt = makeNode(CreateSchemaStmt);
 
@@ -1375,16 +1362,17 @@ CreateExtension(CreateExtensionStmt *stmt)
 
                        /*
                         * CreateSchemaCommand includes CommandCounterIncrement, so new
-                        * schema is now visible
+                        * schema is now visible.
                         */
                        schemaOid = get_namespace_oid(schemaName, false);
                }
        }
-       else
+       else if (!OidIsValid(schemaOid))
        {
                /*
-                * Else, use the current default creation namespace, which is the
-                * first explicit entry in the search_path.
+                * Neither user nor author of the extension specified schema, use the
+                * current default creation namespace, which is the first explicit
+                * entry in the search_path.
                 */
                List       *search_path = fetch_search_path(false);
 
@@ -1423,16 +1411,65 @@ CreateExtension(CreateExtensionStmt *stmt)
                Oid                     reqext;
                Oid                     reqschema;
 
-               /*
-                * We intentionally don't use get_extension_oid's default error
-                * message here, because it would be confusing in this context.
-                */
                reqext = get_extension_oid(curreq, true);
                if (!OidIsValid(reqext))
-                       ereport(ERROR,
-                                       (errcode(ERRCODE_UNDEFINED_OBJECT),
-                                        errmsg("required extension \"%s\" is not installed",
-                                                       curreq)));
+               {
+                       if (cascade)
+                       {
+                               CreateExtensionStmt *ces;
+                               ListCell   *lc;
+                               ObjectAddress addr;
+                               List *cascade_parents;
+
+                               /* Check extension name validity before trying to cascade */
+                               check_valid_extension_name(curreq);
+
+                               /* Check for cyclic dependency between extensions. */
+                               foreach(lc, parents)
+                               {
+                                       char       *pname = (char *) lfirst(lc);
+
+                                       if (strcmp(pname, curreq) == 0)
+                                               ereport(ERROR,
+                                                               (errcode(ERRCODE_INVALID_RECURSION),
+                                                                errmsg("cyclic dependency detected between extensions \"%s\" and \"%s\"",
+                                                                               curreq, stmt->extname)));
+                               }
+
+                               ereport(NOTICE,
+                                               (errmsg("installing required extension \"%s\"",
+                                                               curreq)));
+
+                               /* Create and execute new CREATE EXTENSION statement. */
+                               ces = makeNode(CreateExtensionStmt);
+                               ces->extname = curreq;
+
+                               /* Propagate the CASCADE option */
+                               ces->options = list_make1(d_cascade);
+
+                               /* Propagate the SCHEMA option if given. */
+                               if (d_schema && d_schema->arg)
+                                       ces->options = lappend(ces->options, d_schema);
+
+                               /*
+                                * Pass the current list of parents + the current extension to
+                                * the "child" CreateExtensionInternal().
+                                */
+                               cascade_parents =
+                                       lappend(list_copy(parents), stmt->extname);
+
+                               /* Create the required extension. */
+                               addr = CreateExtensionInternal(ces, cascade_parents);
+                               reqext = addr.objectId;
+                       }
+                       else
+                               ereport(ERROR,
+                                               (errcode(ERRCODE_UNDEFINED_OBJECT),
+                                                errmsg("required extension \"%s\" is not installed",
+                                                               curreq),
+                                                errhint("Use CREATE EXTENSION CASCADE to install required extensions too.")));
+               }
+
                reqschema = get_extension_schema(reqext);
                requiredExtensions = lappend_oid(requiredExtensions, reqext);
                requiredSchemas = lappend_oid(requiredSchemas, reqschema);
@@ -1473,6 +1510,52 @@ CreateExtension(CreateExtensionStmt *stmt)
        return address;
 }
 
+/*
+ * CREATE EXTENSION
+ */
+ObjectAddress
+CreateExtension(CreateExtensionStmt *stmt)
+{
+       /* Check extension name validity before any filesystem access */
+       check_valid_extension_name(stmt->extname);
+
+       /*
+        * Check for duplicate extension name.  The unique index on
+        * pg_extension.extname would catch this anyway, and serves as a backstop
+        * in case of race conditions; but this is a friendlier error message, and
+        * besides we need a check to support IF NOT EXISTS.
+        */
+       if (get_extension_oid(stmt->extname, true) != InvalidOid)
+       {
+               if (stmt->if_not_exists)
+               {
+                       ereport(NOTICE,
+                                       (errcode(ERRCODE_DUPLICATE_OBJECT),
+                                        errmsg("extension \"%s\" already exists, skipping",
+                                                       stmt->extname)));
+                       return InvalidObjectAddress;
+               }
+               else
+                       ereport(ERROR,
+                                       (errcode(ERRCODE_DUPLICATE_OBJECT),
+                                        errmsg("extension \"%s\" already exists",
+                                                       stmt->extname)));
+       }
+
+       /*
+        * We use global variables to track the extension being created, so we can
+        * create only one extension at the same time.
+        */
+       if (creating_extension)
+               ereport(ERROR,
+                               (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+                                errmsg("nested CREATE EXTENSION is not supported")));
+
+
+       /* Finally create the extension. */
+       return CreateExtensionInternal(stmt, NIL);
+}
+
 /*
  * InsertExtensionTuple
  *