Source file middleware_csrf.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
open Lwt.Syntax
(** Log source for everything emitted by the CSRF middleware. *)
let log_src =
  Logs.Src.create "sihl.middleware.csrf" ~doc:"CSRF Middleware"

(* Shadow [Logs] locally so every [Logs.info]/[Logs.err] below is tagged
   with [log_src]. *)
module Logs = (val Logs.src_log log_src : Logs.LOG)
(** Request-env key under which {!set} stores the CSRF token string. *)
let key : string Opium_kernel.Hmap.key =
  let info = "csrf token", Sexplib.Std.sexp_of_string in
  Opium_kernel.Hmap.Key.create info
;;
(** Raised by the CSRF middleware when it cannot proceed (for example when no
    session is attached to the request or a crypto step on the CSRF secret
    fails). The payload is a human-readable reason. *)
exception Crypto_failed of string
(** [find req] returns the CSRF token stored in [req]'s env by {!set}.
    Raises whatever [Opium_kernel.Hmap.find_exn] raises when no token is
    stored (presumably [Not_found] — confirm against Opium). *)
let find req =
  let env = Opium_kernel.Request.env req in
  Opium_kernel.Hmap.find_exn key env
(** [find_opt req] returns the CSRF token stored in [req]'s env by {!set},
    or [None] when no token has been stored.
    The lookup uses the total [Hmap.find] directly instead of catching every
    exception from {!find}, so unrelated failures are not silently turned
    into [None]. *)
let find_opt req = Opium_kernel.Hmap.find key (Opium_kernel.Request.env req)
;;
(** [set token req] returns a copy of [req] whose env maps {!key} to
    [token]; any previously stored token is replaced. *)
let set token req =
  let updated_env =
    req |> Opium_kernel.Request.env |> Opium_kernel.Hmap.add key token
  in
  { req with env = updated_env }
;;
module Make (TokenService : Token.Sig.SERVICE) (SessionService : Session.Sig.SERVICE) =
struct
  (** [create_secret ctx session] creates a new token of kind ["csrf"] with
      length 20, stores its id in [session] under the key ["csrf"], and
      returns the created token. *)
  let create_secret ctx session =
    let* token = TokenService.create ctx ~kind:"csrf" ~length:20 () in
    let* () = SessionService.set ctx session ~key:"csrf" ~value:token.id in
    Lwt.return token
  ;;

  (* Empty 403 response, used whenever the client-supplied CSRF token is
     missing, malformed or does not match the session's secret. *)
  let forbidden () = Http.Response.(create () |> set_status 403) |> Lwt.return

  (* Returns the CSRF secret whose id is stored in [session] under ["csrf"].
     Creates and stores a fresh one when the session has no id or the id
     does not resolve to a token. *)
  let fetch_or_create_secret ctx session =
    let* id = SessionService.get ctx session ~key:"csrf" in
    match id with
    | None -> create_secret ctx session
    | Some token_id ->
      let* token = TokenService.find_by_id_opt ctx token_id in
      (match token with
       | None -> create_secret ctx session
       | Some secret -> Lwt.return secret)
  ;;

  (* Masks [secret_value] with a fresh random salt of the same length.
     Returns the URI-safe base64 encoding of salt ++ xor(salt, secret).
     Raises [Crypto_failed] if the xor step fails; that is a server-side
     failure, not something the client can trigger. *)
  let mask_secret secret_value =
    let salt = Core.Random.bytes ~nr:(String.length secret_value) in
    let secret_chars = secret_value |> String.to_seq |> List.of_seq in
    let encrypted =
      match Utils.Encryption.xor salt secret_chars with
      | None ->
        Logs.err (fun m -> m "MIDDLEWARE: Failed to encrypt CSRF secret");
        raise @@ Crypto_failed "Failed to encrypt CSRF secret"
      | Some enc -> enc
    in
    encrypted
    |> List.append salt
    |> List.to_seq
    |> String.of_seq
    |> Base64.encode_string ~alphabet:Base64.uri_safe_alphabet
  ;;

  (* Reverses [mask_secret]: base64-decodes [token] and removes the salt,
     which is the first half of the decoded bytes.
     Returns [None] on malformed input instead of raising, because [token]
     comes from the client: a garbage value must yield a 403, not an
     exception. *)
  let unmask_token token =
    match Base64.decode ~alphabet:Base64.uri_safe_alphabet token with
    | Error (`Msg msg) ->
      Logs.warn (fun m -> m "MIDDLEWARE: Failed to decode CSRF token. %s" msg);
      None
    | Ok decoded ->
      let salted_cipher = decoded |> String.to_seq |> List.of_seq in
      (match
         Utils.Encryption.decrypt_with_salt
           ~salted_cipher
           ~salt_length:(List.length salted_cipher / 2)
       with
       | None ->
         Logs.warn (fun m -> m "MIDDLEWARE: Failed to decrypt CSRF token");
         None
       | Some dec -> Some (dec |> List.to_seq |> String.of_seq))
  ;;

  (** [m ()] builds the ["csrf"] middleware.

      Every request:
      - resolves (or creates) the session's CSRF secret;
      - stores a freshly salted token for it in the request env (see
        {!set}/{!find}).

      GET requests are passed straight to the handler. Any other request
      must carry a ["csrf"] urlencoded field. The field's unmasked value
      must resolve to the same token as the session's secret, otherwise
      the response is an empty 403. On success the provided token is
      invalidated and the handler runs.

      Raises [Crypto_failed] when no session is attached to the request
      (the session middleware must run first) or when masking the secret
      fails. *)
  let m () =
    let filter handler req =
      let ctx = Http.Request.to_ctx req in
      let session =
        match Middleware_session.find_opt req with
        | Some session -> session
        | None ->
          Logs.info (fun m -> m "Have you applied the session middleware?");
          raise (Crypto_failed "No session found")
      in
      let* secret = fetch_or_create_secret ctx session in
      let req = set (mask_secret secret.value) req in
      if Http.Request.is_get req
      then handler req
      else
        let* value = Http.Request.urlencoded "csrf" req in
        let decrypted_secret =
          match value with
          | None -> None
          | Some value -> unmask_token value
        in
        match decrypted_secret with
        | None -> forbidden ()
        | Some decrypted_secret ->
          let* provided_secret = TokenService.find_opt ctx decrypted_secret in
          (match provided_secret with
           | None -> forbidden ()
           | Some ps ->
             if not @@ Token.equal secret ps
             then forbidden ()
             else
               (* NOTE(review): the session still references this token's
                  id after invalidation; whether later requests get a
                  fresh secret depends on how [TokenService.find_by_id_opt]
                  treats invalidated tokens. Verify. *)
               let* () = TokenService.invalidate ctx ps in
               handler req)
    in
    Opium_kernel.Rock.Middleware.create ~name:"csrf" ~filter
  ;;
end