(* Source file py_wrap.ml *)
module P = Py_base

(* A wrapper relates OCaml values of type ['a] to Python objects: a pair of
   conversion functions plus a display name. *)
module W = struct
  type 'a t =
    { wrap : 'a -> P.pyobject (* OCaml value -> Python object *)
    ; unwrap_exn : P.pyobject -> 'a
      (* Python object -> OCaml value; may raise when the object does not
         have the expected shape (e.g. the tuple wrappers use [failwith]). *)
    ; name : string (* display name of the OCaml type, used by [to_string] *)
    }
end
(* [un w] is the Python -> OCaml conversion carried by wrapper [w]. *)
let un { W.unwrap_exn; _ } = unwrap_exn

(* Description of a function signature crossing the OCaml/Python boundary,
   normally built with [@->] and [returning].
   - [Apply (w, t)]: one argument converted with wrapper [w], followed by
     the rest of the signature [t].
   - [Return w]: the result, converted with wrapper [w].
   The type index is the OCaml function type being described, e.g.
   [int @-> string @-> returning bool] has type [(int -> string -> bool) t]. *)
type 'a t =
  | Apply : 'a W.t * 'b t -> ('a -> 'b) t
  | Return : 'a W.t -> 'a t

(* Renders a signature as ["arg1 -> arg2 -> ... -> result"], using the
   [name] of each wrapper in order. *)
let to_string t =
  let rec names : type a. a t -> string list = function
    | Apply (w, rest) -> w.W.name :: names rest
    | Return w -> [ w.W.name ]
  in
  String.concat " -> " (names t)

(* Number of arguments in a signature, i.e. the number of [Apply] nodes;
   the final [Return] is not counted. Tail-recursive. *)
let length t =
  let rec count : type a. int -> a t -> int = fun n -> function
    | Return _ -> n
    | Apply (_, rest) -> count (n + 1) rest
  in
  count 0 t

(* Signature combinators: [w1 @-> w2 @-> returning r] describes a function
   taking arguments wrapped by [w1] and [w2] and returning one wrapped by [r].
   [@->] is right-associative, so no parentheses are needed. *)
let returning w : _ t = Return w

let ( @-> ) w rest : _ t = Apply (w, rest)

(* [python_fn t pyobject] turns the Python callable [pyobject] into a curried
   OCaml function whose type is described by [t].
   Each OCaml argument is converted with its wrapper's [wrap]. Once the last
   argument has been supplied, the call goes through [P.run]. The result is
   converted back with the return wrapper's [unwrap_exn], which may raise if
   the result does not have the expected shape.
   For a signature with no [Apply] nodes, the call happens immediately. *)
let python_fn t pyobject =
  let rec loop : type a. a t -> P.t list -> a = fun t rev_args ->
    match t with
    | Apply (w, t) -> (fun x -> loop t (P.Ptr (w.W.wrap x) :: rev_args))
    | Return w ->
      (* Arguments were consed onto the front as they arrived, so restore
         call order before handing them to Python. *)
      un w (P.run pyobject (List.rev rev_args) ?kwargs:None)
  in
  loop t []

(* The identity function; used as a no-op conversion. *)
let id = fun x -> x

(* [ocaml_fn t fn args] applies the OCaml function [fn], whose type is
   described by [t], to the Python arguments [args]. It returns the result
   wrapped as a Python object.
   [args] is turned into an array with [P.Object.to_array] (items are not
   converted). Item [i] is then unwrapped with the wrapper of the [i]-th
   [Apply] node of [t].
   Raises [Failure] when the number of items differs from [length t].
   TODO: check whether [args] is a tuple and, if it is not, treat it as a
   single argument (a singleton array). *)
let ocaml_fn t fn args =
  let args = P.Object.to_array id args in
  let expected = length t in
  let got = Array.length args in
  if expected <> got
  then failwith (Printf.sprintf "expected %d arguments, got %d" expected got);
  let rec apply : type a. a t -> a -> int -> P.pyobject = fun t f i ->
    match t with
    | Return w -> w.W.wrap f
    | Apply (w, rest) -> apply rest (f (un w args.(i))) (i + 1)
  in
  apply t fn 0

(* Wrappers for common OCaml types. Each one pairs the OCaml -> Python
   conversion ([wrap]) with the reverse conversion ([unwrap_exn]) and a
   display [name] used when printing signatures. *)
module W_impl = struct
  (* OCaml [unit] <-> Python [None]. Unwrapping any object other than [None]
     raises [Failure "not none"]. *)
  let none =
    let unwrap_exn pyobject =
      if P.Object.is_none pyobject
      then ()
      else failwith "not none"
    in
    { W.wrap = P.Object.none
    ; unwrap_exn
    ; name = "unit"
    }

  let bool =
    { W.wrap = P.Object.from_bool
    ; unwrap_exn = P.Object.to_bool
    ; name = "bool"
    }

  let int =
    { W.wrap = P.PyNumber.create_int
    ; unwrap_exn = P.Object.to_int
    ; name = "int"
    }

  let float =
    { W.wrap = P.PyNumber.create_float
    ; unwrap_exn = P.Object.to_float
    ; name = "float"
    }

  (* OCaml [string] <-> Python unicode string. *)
  let string =
    { W.wrap = P.PyUnicode.create
    ; unwrap_exn = P.Object.to_string
    ; name = "string"
    }

  (* Raw Python objects, passed through unchanged in both directions. *)
  let pyobject = { W.wrap = id; unwrap_exn = id; name = "pyobject" }

  (* OCaml list <-> Python list, converting each element with [w]. *)
  let list w =
    let wrap l = P.PyList.create (List.map w.W.wrap l) in
    let unwrap_exn = P.Object.to_list (un w) in
    { W.wrap; unwrap_exn; name = Printf.sprintf "list[%s]" w.W.name }

  (* Splits a Python object into an array of its items, leaving the items as
     raw Python objects. Used by the tuple wrappers below. *)
  let to_array = P.Object.to_array id

  (* OCaml pairs <-> Python 2-tuples. Unwrapping raises
     [Failure "not a tuple2"] when the object does not have exactly 2 items. *)
  let tuple2 w1 w2 =
    let wrap (x1, x2) = P.PyTuple.create [| w1.W.wrap x1; w2.W.wrap x2 |] in
    let unwrap_exn o =
      match to_array o with
      | [| o1; o2 |] -> un w1 o1, un w2 o2
      | _ -> failwith "not a tuple2"
    in
    let name = Printf.sprintf "(%s, %s)" w1.W.name w2.W.name in
    { W.wrap; unwrap_exn; name }

  (* OCaml triples <-> Python 3-tuples; raises [Failure "not a tuple3"] on an
     item-count mismatch when unwrapping. *)
  let tuple3 w1 w2 w3 =
    let wrap (x1, x2, x3) =
      P.PyTuple.create [| w1.W.wrap x1; w2.W.wrap x2; w3.W.wrap x3 |]
    in
    let unwrap_exn o =
      match to_array o with
      | [| o1; o2; o3 |] -> un w1 o1, un w2 o2, un w3 o3
      | _ -> failwith "not a tuple3"
    in
    let name = Printf.sprintf "(%s, %s, %s)" w1.W.name w2.W.name w3.W.name in
    { W.wrap; unwrap_exn; name }

  (* OCaml 4-tuples <-> Python 4-tuples; raises [Failure "not a tuple4"] on an
     item-count mismatch when unwrapping. *)
  let tuple4 w1 w2 w3 w4 =
    let wrap (x1, x2, x3, x4) =
      P.PyTuple.create
        [| w1.W.wrap x1; w2.W.wrap x2; w3.W.wrap x3; w4.W.wrap x4 |]
    in
    let unwrap_exn o =
      match to_array o with
      | [| o1; o2; o3; o4 |] -> un w1 o1, un w2 o2, un w3 o3, un w4 o4
      | _ -> failwith "not a tuple4"
    in
    let name =
      Printf.sprintf "(%s, %s, %s, %s)"
        w1.W.name w2.W.name w3.W.name w4.W.name
    in
    { W.wrap; unwrap_exn; name }

  (* OCaml 5-tuples <-> Python 5-tuples; raises [Failure "not a tuple5"] on an
     item-count mismatch when unwrapping. *)
  let tuple5 w1 w2 w3 w4 w5 =
    let wrap (x1, x2, x3, x4, x5) =
      P.PyTuple.create
        [| w1.W.wrap x1; w2.W.wrap x2; w3.W.wrap x3; w4.W.wrap x4; w5.W.wrap x5 |]
    in
    let unwrap_exn o =
      match to_array o with
      | [| o1; o2; o3; o4; o5 |] ->
        un w1 o1, un w2 o2, un w3 o3, un w4 o4, un w5 o5
      | _ -> failwith "not a tuple5"
    in
    let name =
      Printf.sprintf "(%s, %s, %s, %s, %s)"
        w1.W.name w2.W.name w3.W.name w4.W.name w5.W.name
    in
    { W.wrap; unwrap_exn; name }

  (* OCaml association list [(key * value) list] <-> Python dict. Keys and
     values are converted with [w_key] and [w_value] respectively. *)
  let dict w_key w_value =
    let wrap l =
      List.map (fun (k, v) -> w_key.W.wrap k, w_value.W.wrap v) l |> P.PyDict.create
    in
    let unwrap_exn = P.PyDict.items (un w_key) (un w_value) in
    let name = Printf.sprintf "dict[%s: %s]" w_key.W.name w_value.W.name in
    { W.wrap; unwrap_exn; name }
end

(* Re-export the wrappers defined in [W_impl] at the top level of this
   module: [none], [bool], [int], [float], [string], [pyobject], [list],
   [tuple2]..[tuple5], [dict], and the [to_array] helper. *)
include W_impl