Source file database.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
(** Database connection and configuration utilities. *)

open Lwt.Infix
open Caqti_request.Infix
open Caqti_type.Std

(** Extract the host component of [uri].

    @return [Ok host], or a [UrlParseError] when [uri] carries no host. *)
let get_hostname (uri : Uri.t) : (string, Types.error) result =
  Uri.host uri
  |> Option.to_result
       ~none:
         (Types.DatabaseError
            (Types.UrlParseError "Could not parse host from DATABASE_URL"))

(** Extract the explicit port of [uri].

    @return [Ok port], or a [UrlParseError] when [uri] has no explicit port. *)
let get_port (uri : Uri.t) : (int, Types.error) result =
  Option.fold
    ~none:
      (Error
         (Types.DatabaseError
            (Types.UrlParseError "Could not parse port from DATABASE_URL")))
    ~some:Result.ok (Uri.port uri)

(** Extract the database name from the path of [uri].

    The path must be of the form ["/name"] with a non-empty [name]; the leading
    slash is stripped. Anything after it, including further ['/'] characters,
    is returned as-is.

    @return [Ok name], or a [UrlParseError] describing why the path was
      rejected (empty path, missing leading slash, or empty name). *)
let get_database (uri : Uri.t) : (string, Types.error) result =
  let parse_error msg = Error (Types.DatabaseError (Types.UrlParseError msg)) in
  let path = Uri.path uri in
  let len = String.length path in
  if len = 0 then
    parse_error "Could not parse database from DATABASE_URL (empty path)"
  else if path.[0] <> '/' then
    parse_error
      "Could not parse database from DATABASE_URL (invalid path format)"
  else if len = 1 then
    parse_error
      "Could not parse database from DATABASE_URL (empty database name)"
  else Ok (String.sub path 1 (len - 1))

(** Mask the password in a connection URL so it can be shown or logged safely.

    The password is replaced by the fixed string [*****], so the output does
    not leak the password's length. A URL with no password (for example a
    SQLite path) is returned untouched.

    [Uri.to_string] would percent-encode a literal [*], so the URL is first
    rendered with an alphanumeric placeholder as the password, and the
    placeholder is then swapped for the mask textually. *)
let redact_url (url : string) : string =
  let uri = Uri.of_string url in
  if Option.is_none (Uri.password uri) then url
  else
    let placeholder = "MIGRAPWREDACTED0" in
    let rendered =
      Uri.with_password uri (Some placeholder) |> Uri.to_string
    in
    let plen = String.length placeholder in
    let total = String.length rendered in
    (* First offset at which the placeholder appears, if any. *)
    let rec locate pos =
      if pos + plen > total then None
      else if String.sub rendered pos plen = placeholder then Some pos
      else locate (pos + 1)
    in
    match locate 0 with
    | None -> rendered
    | Some pos ->
        let before = String.sub rendered 0 pos in
        let after = String.sub rendered (pos + plen) (total - pos - plen) in
        before ^ "*****" ^ after

(** Read the connection URL from the [DATABASE_URL] environment variable.

    @return [Ok url], or a [UrlParseError] when the variable is unset. *)
let get_database_url () : (string, Types.error) result =
  Sys.getenv_opt "DATABASE_URL"
  |> Option.to_result
       ~none:
         (Types.DatabaseError
            (Types.UrlParseError "DATABASE_URL environment variable not set"))

(** [string_contains haystack needle] is [true] iff [needle] occurs as a
    contiguous substring of [haystack]. The empty needle occurs in every
    string.

    This is a plain character scan, used instead of the [Str] library because
    [Str] keeps global match state, which is not safe under Lwt's concurrent
    scheduling. *)
let string_contains (haystack : string) (needle : string) : bool =
  let n = String.length needle in
  let last_start = String.length haystack - n in
  (* Does [needle] match [haystack] starting at offset [start]? *)
  let matches_at start =
    let rec same k =
      k >= n || (haystack.[start + k] = needle.[k] && same (k + 1))
    in
    same 0
  in
  let rec scan start =
    start <= last_start && (matches_at start || scan (start + 1))
  in
  n = 0 || scan 0

(** Heuristically decide whether a Caqti error message [err_msg] means the
    driver for the URL's scheme is not available. It matches on the substrings
    ["suitable driver"] and ["not found"].

    NOTE(review): ["not found"] is a broad pattern and may also match
    unrelated connection errors; confirm against the actual Caqti messages. *)
let is_missing_driver_error (err_msg : string) : bool =
  List.exists (string_contains err_msg) [ "suitable driver"; "not found" ]

(** Render a multi-line "driver not installed" message for the scheme of
    [database_url].

    The message suggests the opam package for known schemes ([postgresql],
    [postgres], [mariadb], [mysql], [sqlite3]) and lists all available drivers.
    A URL without a scheme is reported as ["unknown"]. The caller is expected
    to have already classified the failure as a missing-driver error (see
    {!is_missing_driver_error}). *)
let missing_driver_message (database_url : string) : string =
  let scheme =
    match Uri.scheme (Uri.of_string database_url) with
    | Some s -> s
    | None -> "unknown"
  in
  let driver_name, install_cmd =
    match scheme with
    | "postgresql" | "postgres" ->
        ("PostgreSQL", "opam install caqti-driver-postgresql")
    | "mariadb" | "mysql" ->
        ("MariaDB/MySQL", "opam install caqti-driver-mariadb")
    | "sqlite3" -> ("SQLite", "opam install caqti-driver-sqlite3")
    | other -> (other, "Unknown database scheme: " ^ other)
  in
  String.concat "\n"
    [
      Printf.sprintf "No database driver found for '%s://'" scheme;
      "";
      Printf.sprintf "The %s driver is not installed. To fix this:" driver_name;
      "  " ^ install_cmd;
      "";
      "Available drivers:";
      "  - caqti-driver-postgresql  (for postgresql://)";
      "  - caqti-driver-mariadb     (for mariadb://, mysql://)";
      "  - caqti-driver-sqlite3     (for sqlite3://)";
    ]

(** Heuristic: does [url] contain two or more ['@'] characters?

    A well-formed authority has at most one ['@'], which separates the
    credentials from the host. Several of them usually mean an unencoded ['@']
    in the password, which makes the URL parse with the wrong host. The count
    covers the whole URL, so an ['@'] in the query string also counts. *)
let likely_unencoded_credentials (url : string) : bool =
  let at_signs = ref 0 in
  String.iter (fun c -> if c = '@' then incr at_signs) url;
  !at_signs >= 2

(** Open a single connection to [database_url]. Use it for one-off operations
    or transactions.

    The URL is passed through [Dialect.normalize_url] before connecting. On
    failure the Caqti error is classified as follows:
    - a missing driver becomes a [ValidationError] with install instructions;
    - if the URL contains more than one ['@'], the error becomes a
      [ValidationError] that adds a percent-encoding hint;
    - anything else becomes [ConnectionFailed]. *)
let connect_db (database_url : string) :
    (Types.db_conn, Types.error) result Lwt.t =
  let target = Uri.of_string (Dialect.normalize_url database_url) in
  let validation_error msg =
    Error (Types.DatabaseError (Types.ValidationError msg))
  in
  (* Turn a raw Caqti connect error into the most helpful [Types.error]. *)
  let explain err =
    let detail = Caqti_error.show err in
    if is_missing_driver_error detail then
      validation_error (missing_driver_message database_url)
    else if likely_unencoded_credentials database_url then
      validation_error
        (Printf.sprintf
           "%s\n\n\
            Hint: the connection URL contains more than one '@'. If \
            your username or password contains '@' (or ':' '/' '?' \
            '#'), percent-encode it - e.g. '@' becomes '%%40' - \
            otherwise it is misread as the host."
           detail)
    else
      Error (Types.DatabaseError (Types.ConnectionFailed ("connect_db", err)))
  in
  Caqti_lwt_unix.connect target >|= function
  | Ok conn -> Ok (conn :> Types.db_conn)
  | Error err -> explain err

(** Connect to [database_url], run [f] on the connection, and always
    disconnect afterwards, whether [f] succeeds or raises.

    Any exception raised by [f] is caught and returned as a [ValidationError]
    whose message is ["Unexpected error: "] followed by
    [Printexc.to_string exn]. No backtrace is included.

    @param database_url Database connection URL
    @param f Function to execute with the open connection
    @return [Ok] with the result of [f], or the connection/exception error *)
let with_db (database_url : string) (f : Types.db_conn -> 'a Lwt.t) :
    ('a, Types.error) Lwt_result.t =
  let unexpected exn =
    Types.DatabaseError
      (Types.ValidationError
         (Printf.sprintf "Unexpected error: %s" (Printexc.to_string exn)))
  in
  connect_db database_url >>= function
  | Error err -> Lwt.return_error err
  | Ok db ->
      let module Db = (val db : Caqti_lwt.CONNECTION) in
      let guarded () =
        Lwt.catch
          (fun () -> Lwt.map Result.ok (f db))
          (fun exn -> Lwt.return_error (unexpected exn))
      in
      Lwt.finalize guarded Db.disconnect

(** Derive the URL of the dialect's admin database, which is used to create and
    drop databases, from [uri].

    The original URI is transformed rather than rebuilt from its parts. That
    preserves userinfo (including the password), query parameters such as
    [sslmode], and IPv6 host bracketing. The scheme is replaced with the
    dialect's canonical one. An explicit port is always set: the URI's own port
    if present, otherwise the dialect's default port, otherwise 5432.

    @param dialect Database dialect
    @param uri Parsed database URL; must have a host
    @return The admin connection URL, or an error when the dialect has no
      admin database or [uri] has no host *)
let get_admin_database_url (dialect : Dialect.t) (uri : Uri.t) :
    (string, Types.error) result =
  let module D = (val Dialect.get_dialect dialect : Dialect.DIALECT) in
  let scheme_for = function
    | Dialect.PostgreSQL -> "postgresql"
    | Dialect.MariaDB -> "mariadb"
    | Dialect.SQLite -> "sqlite3"
  in
  match (D.admin_database, get_hostname uri) with
  | None, _ ->
      Error
        (Types.DatabaseError
           (Types.ValidationError
              "This database type does not support admin database connections"))
  | Some _, (Error _ as no_host) -> no_host
  | Some admin_db, Ok _ ->
      let fallback_port = Option.value D.default_port ~default:5432 in
      let port = Option.value (Uri.port uri) ~default:fallback_port in
      let with_scheme = Uri.with_scheme uri (Some (scheme_for dialect)) in
      let with_port = Uri.with_port with_scheme (Some port) in
      let admin_uri = Uri.with_path with_port ("/" ^ admin_db) in
      Ok (Uri.to_string admin_uri)

(** Open a connection to the admin database for [database_url]'s dialect and
    call [f conn target_db_name]. The connection is always closed afterwards.

    The target database name is taken from the URL path. Names containing ['/']
    are rejected before any connection is made. Intended for server dialects;
    SQLite has no admin database. *)
let with_admin_connection (dialect : Dialect.t) (database_url : string)
    (f : Types.db_conn -> string -> (unit, Types.error) Lwt_result.t) :
    (unit, Types.error) Lwt_result.t =
  let uri = Uri.of_string database_url in
  let reject_slash db_name =
    (* '/' cannot appear in a database/schema name; its presence means the URL
       path has an extra segment. Refuse it instead of splicing it into
       CREATE/DROP DATABASE. *)
    Error
      (Types.DatabaseError
         (Types.UrlParseError
            (Printf.sprintf
               "Invalid database name %S: a database name cannot contain '/' \
                (check the path in your DATABASE_URL)"
               db_name)))
  in
  let resolved =
    Result.bind (get_database uri) (fun db_name ->
        if String.contains db_name '/' then reject_slash db_name
        else
          Result.map
            (fun admin_url -> (db_name, admin_url))
            (get_admin_database_url dialect uri))
  in
  match resolved with
  | Error err -> Lwt.return_error err
  | Ok (db_name, admin_url) -> (
      connect_db admin_url >>= function
      | Error err -> Lwt.return_error err
      | Ok db ->
          let module Conn = (val db : Caqti_lwt.CONNECTION) in
          Lwt.finalize (fun () -> f db db_name) Conn.disconnect)

(** Create the database named in [database_url] if it does not already exist.

    SQLite needs nothing here: a file-backed database is created on first
    connection, and [:memory:] needs no setup. PostgreSQL and MariaDB connect
    to the admin database, check for existence with the dialect's query, and
    run CREATE DATABASE only if the check says it is missing.

    The check-then-create sequence is not atomic. Two concurrent callers may
    race, and one of them may then fail. That is acceptable for the intended
    development workflows.

    @param database_url Database connection URL
    @return [Ok ()] or the first error encountered *)
let create_database (database_url : string) : (unit, Types.error) Lwt_result.t =
  let query_failed step err =
    Lwt.return_error (Types.DatabaseError (Types.QueryFailed (step, err)))
  in
  match Dialect.detect_from_url database_url with
  | Error msg ->
      Lwt.return_error (Types.DatabaseError (Types.UrlParseError msg))
  | Ok Dialect.SQLite -> Lwt.return_ok ()
  | Ok dialect ->
      let module D = (val Dialect.get_dialect dialect : Dialect.DIALECT) in
      with_admin_connection dialect database_url (fun db db_name ->
          let module Conn = (val db : Caqti_lwt.CONNECTION) in
          let exists_query = (string ->! bool) D.database_exists_sql in
          Conn.find exists_query db_name >>= function
          | Error err -> query_failed "check database existence" err
          | Ok true -> Lwt.return_ok ()
          | Ok false -> (
              let create_query =
                (unit ->. unit) (D.create_database_sql db_name)
              in
              Conn.exec create_query () >>= function
              | Ok () -> Lwt.return_ok ()
              | Error err -> query_failed "create database" err))

(** Delete the SQLite database file at [path], then its rollback-journal and
    WAL sidecar files ([path-journal], [path-wal], [path-shm]). Files that do
    not exist are skipped. Any exception is reported as a [ValidationError]. *)
let drop_sqlite_files (path : string) : (unit, Types.error) Lwt_result.t =
  let remove_if_present file =
    if Sys.file_exists file then Lwt_unix.unlink file else Lwt.return_unit
  in
  Lwt.catch
    (fun () ->
      (* The main file goes first: if it cannot be removed we stop, rather than
         stripping the journal from a database that is still on disk. The
         sidecars must go too, because SQLite would replay a stale hot journal
         or WAL into a database later recreated at the same path. *)
      Lwt_list.iter_s remove_if_present
        [ path; path ^ "-journal"; path ^ "-wal"; path ^ "-shm" ]
      >|= fun () -> Ok ())
    (fun exn ->
      Lwt.return_error
        (Types.DatabaseError
           (Types.ValidationError
              (Printf.sprintf "Failed to delete SQLite file: %s"
                 (Printexc.to_string exn)))))

(** Drop the database named in [database_url] if it exists (dialect-aware).

    For SQLite, the database file is deleted together with its [-journal],
    [-wal] and [-shm] sidecar files. [:memory:] and an empty path are no-ops.
    For PostgreSQL and MariaDB, the function connects to the admin database and
    runs the dialect's DROP DATABASE statement.

    @param database_url Database connection URL
    @return [Ok ()] or error *)
let drop_database (database_url : string) : (unit, Types.error) Lwt_result.t =
  match Dialect.detect_from_url database_url with
  | Error msg ->
      Lwt.return_error (Types.DatabaseError (Types.UrlParseError msg))
  | Ok Dialect.SQLite -> (
      let normalized_url = Dialect.normalize_url database_url in
      match Uri.path (Uri.of_string normalized_url) with
      (* An empty path must not reach [drop_sqlite_files]: it would probe
         "-wal" etc. relative to the current directory. *)
      | "" | ":memory:" -> Lwt.return_ok ()
      | path -> drop_sqlite_files path)
  | Ok dialect ->
      let module D = (val Dialect.get_dialect dialect : Dialect.DIALECT) in
      with_admin_connection dialect database_url (fun db db_name ->
          let module Conn = (val db : Caqti_lwt.CONNECTION) in
          let drop_query = (unit ->. unit) (D.drop_database_sql db_name) in
          Conn.exec drop_query () >>= function
          | Error err ->
              Lwt.return_error
                (Types.DatabaseError (Types.QueryFailed ("drop database", err)))
          | Ok () -> Lwt.return_ok ())