(* corda/sdk/edger8r/linux/CodeGen.ml -- 1559 lines, 63 KiB, OCaml *)

(*
* Copyright (C) 2011-2016 Intel Corporation. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in
* the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Intel Corporation nor the names of its
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*)
open Printf
open Util (* for failwithf *)
(* --------------------------------------------------------------------
* We first introduce a `parse_enclave_ast' function (see below) to
* parse a value of type `Ast.enclave' into an `enclave_content' record.
* --------------------------------------------------------------------
*)
(* This record type is used to better organize a value of Ast.enclave.
 * `parse_enclave_ast' (defined further down in this file) fills it in
 * from the enclave AST. *)
type enclave_content = {
file_shortnm : string; (* the short name of original EDL file *)
enclave_name : string; (* the normalized C identifier *)
(* payloads of the Ast.Include expressions *)
include_list : string list;
(* payloads of the Ast.Importing expressions *)
import_exprs : Ast.import_decl list;
(* payloads of the Ast.Composite expressions *)
comp_defs : Ast.composite_type list;
(* Ast.Trusted members found in the Ast.Interface expressions *)
tfunc_decls : Ast.trusted_func list;
(* Ast.Untrusted members found in the Ast.Interface expressions *)
ufunc_decls : Ast.untrusted_func list;
}
(* Whether to prefix untrusted proxy with Enclave name *)
let g_use_prefix = ref false
(* Directory put in front of the generated untrusted file names
 * (see get_uheader_name / get_usource_name); defaults to the current
 * directory. *)
let g_untrusted_dir = ref "."
(* Directory put in front of the generated trusted file names
 * (see get_theader_name / get_tsource_name); defaults to the current
 * directory. *)
let g_trusted_dir = ref "."
(* An `enclave_content' value whose two names are empty strings and
 * whose lists are all empty. *)
let empty_ec =
{ file_shortnm = "";
enclave_name = "";
include_list = [];
import_exprs = [];
comp_defs = [];
tfunc_decls = [];
ufunc_decls = []; }
(* Name of the function declared by a trusted function. *)
let get_tf_fname (tf: Ast.trusted_func) =
  let fdecl = tf.Ast.tf_fdecl in
  fdecl.Ast.fname

(* The `tf_is_priv' flag of a trusted function. *)
let is_priv_ecall (tf: Ast.trusted_func) =
  let { Ast.tf_is_priv = priv; _ } = tf in
  priv

(* Name of the function declared by an untrusted function. *)
let get_uf_fname (uf: Ast.untrusted_func) =
  let fdecl = uf.Ast.uf_fdecl in
  fdecl.Ast.fname
(* Names of all trusted functions of the enclave, in declaration order. *)
let get_trusted_func_names (ec: enclave_content) =
  ec.tfunc_decls |> List.map get_tf_fname

(* Names of all untrusted functions of the enclave, in declaration order. *)
let get_untrusted_func_names (ec: enclave_content) =
  ec.ufunc_decls |> List.map get_uf_fname
(* Project a list of trusted functions onto their function declarations. *)
let tf_list_to_fd_list (tfs: Ast.trusted_func list) =
  let fdecl_of (tf: Ast.trusted_func) = tf.Ast.tf_fdecl in
  List.map fdecl_of tfs

(* Project a list of trusted functions onto their `private' flags. *)
let tf_list_to_priv_list (tfs: Ast.trusted_func list) =
  tfs |> List.map is_priv_ecall
(* Names of the trusted functions that are flagged private, in the
 * order they appear in `tfs'. *)
let get_priv_ecall_names (tfs: Ast.trusted_func list) =
  let private_tfs = List.filter is_priv_ecall tfs in
  List.map get_tf_fname private_tfs
(* Project a list of untrusted functions onto their function declarations. *)
let uf_list_to_fd_list (ufs: Ast.untrusted_func list) =
  let fdecl_of (uf: Ast.untrusted_func) = uf.Ast.uf_fdecl in
  List.map fdecl_of ufs
(* Collect the `allow(...)' lists of all untrusted functions into one
 * list and pass it through `dedup_list'. *)
let get_allowed_names (ufs: Ast.untrusted_func list) =
  let allow_list_of (uf: Ast.untrusted_func) = uf.Ast.uf_allow_list in
  ufs
  |> List.map allow_list_of
  |> List.flatten
  |> dedup_list
(* Walk the enclave AST exactly once and sort its expressions into an
 * `enclave_content' record.  Every list in the result keeps the order
 * in which the items occur in `e.Ast.eexpr'. *)
let parse_enclave_ast (e: Ast.enclave) =
  (* All accumulators are built in reverse and flipped at the end. *)
  let collect_func (tfs, ufs) ef =
    match ef with
      Ast.Trusted f   -> (f :: tfs, ufs)
    | Ast.Untrusted f -> (tfs, f :: ufs)
  in
  let collect (incs, imps, comps, funcs) ex =
    match ex with
      Ast.Composite x  -> (incs, imps, x :: comps, funcs)
    | Ast.Include x    -> (x :: incs, imps, comps, funcs)
    | Ast.Importing x  -> (incs, x :: imps, comps, funcs)
    | Ast.Interface xs ->
        (incs, imps, comps, List.fold_left collect_func funcs xs)
  in
  let (rev_incs, rev_imps, rev_comps, (rev_tfs, rev_ufs)) =
    List.fold_left collect ([], [], [], ([], [])) e.Ast.eexpr
  in
  { file_shortnm = e.Ast.ename;
    enclave_name = Util.to_c_identifier e.Ast.ename;
    include_list = List.rev rev_incs;
    import_exprs = List.rev rev_imps;
    comp_defs    = List.rev rev_comps;
    tfunc_decls  = List.rev rev_tfs;
    ufunc_decls  = List.rev rev_ufs; }
(* True only for a pointer parameter whose type is `Ast.Foreign' and
 * whose attributes carry the `pa_isary' flag. *)
let is_foreign_array (pt: Ast.parameter_type) =
  match pt with
    Ast.PTPtr(Ast.Foreign _, attr) -> attr.Ast.pa_isary
  | Ast.PTPtr _ | Ast.PTVal _ -> false
(* A naked function returns `void' and takes no parameters. *)
let is_naked_func (fd: Ast.func_decl) =
  match fd.Ast.rtype, fd.Ast.plist with
    Ast.Void, [] -> true
  | _ -> false
(*
 * If the user only defined trusted functions with neither parameters
 * nor a return value, the generated trusted bridges will not call any
 * tRTS routine.  If the real trusted functions do not call tRTS
 * functions either (highly possible), the MSVC linker will not link
 * tRTS into the resulting enclave.
 *
 * To avoid that, this returns a snippet that references a tRTS symbol
 * when the enclave has no untrusted functions and every trusted
 * function is naked; otherwise it returns the empty string.
 *)
let tbridge_gen_dummy_variable (ec: enclave_content) =
  let every_ecall_is_naked () =
    List.for_all (fun tfd -> is_naked_func tfd.Ast.tf_fdecl) ec.tfunc_decls
  in
  if ec.ufunc_decls = [] && every_ecall_is_naked () then
    sprintf "\n#ifdef _MSC_VER\n\
             \t/* In case enclave `%s' doesn't call any tRTS function. */\n\
             \tvolatile int force_link_trts = sgx_is_within_enclave(NULL, 0);\n\
             \t(void) force_link_trts; /* avoid compiler warning */\n\
             #endif\n\n" ec.enclave_name
  else ""
(* Convert a parameter from Array form into Pointer form.
 * e.g.: int array[10][20] => [count = 200] int* array
 *
 * Only pointer parameters whose declarator is an array are rewritten:
 * the dimensions are dropped from the declarator, the type is wrapped
 * in `Ast.Ptr', and the size attribute is replaced by a count equal to
 * the product of the dimensions.  Everything else is returned as is.
 *
 * Called when generating proxy/bridge code and the marshaling structure.
 *)
let conv_array_to_ptr (pd: Ast.pdecl): Ast.pdecl =
  let (pt, declr) = pd in
  match pt with
    Ast.PTVal _ -> pd
  | Ast.PTPtr(aty, pa) ->
      if not (Ast.is_array declr) then pd
      else
        (* XXX: assume the size of each dimension will be > 0. *)
        let total = List.fold_left ( * ) 1 declr.Ast.array_dims in
        let counted =
          { Ast.empty_ptr_size with Ast.ps_count = Some (Ast.ANumber total) }
        in
        let ptr_attr = { pa with Ast.pa_size = counted } in
        let scalar_declr = { declr with Ast.array_dims = [] } in
        (Ast.PTPtr(Ast.Ptr aty, ptr_attr), scalar_declr)
(* ------------------------------------------------------------------
* Code generation for edge-routines.
* ------------------------------------------------------------------
*)
(* Little functions for naming of a struct and its members etc *)

(* Identifier used for the return value, and a plain (non-array)
 * declarator carrying that identifier. *)
let retval_name = "retval"
let retval_declr = { Ast.identifier = retval_name; Ast.array_dims = []; }

(* Fixed identifiers used in the generated C code. *)
let eid_name = "eid"
let ms_ptr_name = "pms"
let ms_struct_val = "ms"

(* Member / type names of the marshaling structure. *)
let mk_ms_member_name (pname: string) = String.concat "" ["ms_"; pname]
let mk_ms_struct_name (fname: string) = String.concat "" ["ms_"; fname; "_t"]
let ms_retval_name = mk_ms_member_name retval_name

(* Trusted bridge name for the function `fname'. *)
let mk_tbridge_name (fname: string) = String.concat "" ["sgx_"; fname]

(* C expression reading parameter `name' through the `ms' pointer. *)
let mk_parm_accessor name =
  let member = mk_ms_member_name name in
  sprintf "%s->%s" ms_struct_val member

(* Names of the helper variables derived from a parameter name. *)
let mk_tmp_var name = String.concat "" ["_tmp_"; name]
let mk_len_var name = String.concat "" ["_len_"; name]
let mk_in_var name = String.concat "" ["_in_"; name]

(* Name of the OCALL table of the given enclave. *)
let mk_ocall_table_name enclave_name =
  String.concat "" ["ocall_table_"; enclave_name]
(* Un-trusted bridge name is prefixed with enclave file short name. *)
let mk_ubridge_name (file_shortnm: string) (funcname: string) =
  String.concat "_" [file_shortnm; funcname]

(* C prototype of the un-trusted bridge: a static function taking the
 * marshaling structure as an untyped pointer. *)
let mk_ubridge_proto (file_shortnm: string) (funcname: string) =
  let bridge = mk_ubridge_name file_shortnm funcname in
  sprintf "static sgx_status_t SGX_CDECL %s(void* %s)" bridge ms_ptr_name
(* Common macro definitions: the stdlib include, the SGX_CAST macro and
 * the opening half of the C++ linkage guard. *)
let common_macros =
  String.concat "" [
    "#include <stdlib.h> /* for size_t */\n\n";
    "#define SGX_CAST(type, item) ((type)(item))\n\n";
    "#ifdef __cplusplus\n";
    "extern \"C\" {\n";
    "#endif\n";
  ]

(* Header footer: closes the C++ linkage guard and the include guard. *)
let header_footer =
  String.concat "" [
    "\n#ifdef __cplusplus\n";
    "}\n";
    "#endif /* __cplusplus */\n";
    "\n#endif\n";
  ]
(* Little functions for generating file names. *)

(* Untrusted side: header (short and with directory) and source. *)
let get_uheader_short_name (file_shortnm: string) =
  String.concat "" [file_shortnm; "_u.h"]
let get_uheader_name (file_shortnm: string) =
  String.concat separator_str
    [!g_untrusted_dir; get_uheader_short_name file_shortnm]
let get_usource_name (file_shortnm: string) =
  String.concat separator_str [!g_untrusted_dir; file_shortnm ^ "_u.c"]

(* Trusted side: header (short and with directory) and source. *)
let get_theader_short_name (file_shortnm: string) =
  String.concat "" [file_shortnm; "_t.h"]
let get_theader_name (file_shortnm: string) =
  String.concat separator_str
    [!g_trusted_dir; get_theader_short_name file_shortnm]
let get_tsource_name (file_shortnm: string) =
  String.concat separator_str [!g_trusted_dir; file_shortnm ^ "_t.c"]
(* Construct the string of structure definition.
 * `fs' is the already-rendered member list, `name' the struct name. *)
let mk_struct_decl (fs: string) (name: string) =
  String.concat "" ["typedef struct "; name; " {\n"; fs; "} "; name; ";\n"]

(* Construct the string of union definition.
 * `fs' is the already-rendered member list, `name' the union name. *)
let mk_union_decl (fs: string) (name: string) =
  String.concat "" ["typedef union "; name; " {\n"; fs; "} "; name; ";\n"]
(* Generate a definition of enum.
 * Each element is written on its own tab-indented line followed by a
 * comma.  An enum with an empty name is emitted as an anonymous `enum';
 * otherwise a `typedef enum' carrying the name is emitted. *)
let mk_enum_def (e: Ast.enum_def) =
  let render_ele ((key, value): Ast.enum_ele) =
    let text =
      match value with
        Ast.EnumValNone -> key
      | Ast.EnumVal ev -> sprintf "%s = %s" key (Ast.attr_value_to_string ev)
    in
    "\t" ^ text ^ ",\n"
  in
  let body = String.concat "" (List.map render_ele e.Ast.enbody) in
  match e.Ast.enname with
    "" -> sprintf "enum {\n%s};\n" body
  | tag -> sprintf "typedef enum %s {\n%s} %s;\n" tag body tag
(* Get the array declaration from a list of array dimensions.
 * Empty `ns' indicates the corresponding declarator is a simple
 * identifier, so the result is the empty string.
 * Element of value -1 means that user does not specify the dimension
 * size; it is rendered as an empty pair of brackets.
 *)
let get_array_dims (ns: int list) =
  let dim_str = function
    | -1 -> "[]"
    | n -> sprintf "[%d]" n
  in
  String.concat "" (List.map dim_str ns)
(* Render `type identifier[dims]' for the given type and declarator. *)
let get_typed_declr_str (ty: Ast.atype) (declr: Ast.declarator) =
  let { Ast.identifier = ident; Ast.array_dims = dims } = declr in
  sprintf "%s %s%s" (Ast.get_tystr ty) ident (get_array_dims dims)
(* Construct a member declaration string: a tab, the typed declarator,
 * a semicolon and a newline. *)
let mk_member_decl (ty: Ast.atype) (declr: Ast.declarator) =
  let typed_declr = get_typed_declr_str ty declr in
  sprintf "\t%s;\n" typed_declr
(* Member declaration inside the marshaling structure.
 *
 * Note that, for a foreign array type `foo_array_t' we will generate
 * foo_array_t* ms_field;
 * in the marshaling data structure to keep the pass-by-address scheme
 * as in the C programming language.
 *)
let mk_ms_member_decl (pt: Ast.parameter_type) (declr: Ast.declarator) =
  let type_str = Ast.get_tystr (Ast.get_param_atype pt) in
  let star =
    match is_foreign_array pt with
      true -> "* "
    | false -> ""
  in
  let member = mk_ms_member_name declr.Ast.identifier in
  let dims = get_array_dims declr.Ast.array_dims in
  sprintf "\t%s%s %s%s;\n" type_str star member dims
(* Generate data structure definition for a struct, union or enum. *)
let gen_comp_def (st: Ast.composite_type) =
  (* Concatenated member declarations of a struct/union body. *)
  let members_str mlist =
    mlist
    |> List.map (fun (ty, declr) -> mk_member_decl ty declr)
    |> String.concat ""
  in
  match st with
    Ast.StructDef s -> mk_struct_decl (members_str s.Ast.mlist) s.Ast.sname
  | Ast.UnionDef u -> mk_union_decl (members_str u.Ast.mlist) u.Ast.sname
  | Ast.EnumDef e -> mk_enum_def e
(* Generate one quoted '#include' line per entry of `xs', in order. *)
let gen_include_list (xs: string list) =
  let include_line path = sprintf "#include \"%s\"\n" path in
  String.concat "" (List.map include_line xs)
(* Get the type string from 'parameter_type' *)
let get_param_tystr (pt: Ast.parameter_type) =
  pt |> Ast.get_param_atype |> Ast.get_tystr
(* Generate marshaling structure definition.
 *
 * `errno' is a pre-rendered member declaration (possibly empty) placed
 * in front of the parameter members.  Array parameters are converted to
 * pointer form first.  For a non-void function a member for the return
 * value comes first.  A void function without parameters and without
 * the extra member needs no structure, so "" is returned. *)
let gen_marshal_struct (fd: Ast.func_decl) (errno: string) =
  let param_members =
    fd.Ast.plist
    |> List.map conv_array_to_ptr
    |> List.map (fun (pt, declr) -> mk_ms_member_decl pt declr)
    |> String.concat ""
  in
  let members = errno ^ param_members in
  let struct_name = mk_ms_struct_name fd.Ast.fname in
  match fd.Ast.rtype with
    Ast.Void ->
      if fd.Ast.plist = [] && errno = "" then ""
      else mk_struct_decl members struct_name
  | rtype ->
      let rv_member = mk_ms_member_decl (Ast.PTVal rtype) retval_declr in
      mk_struct_decl (rv_member ^ members) struct_name
(* Marshaling structure of a trusted function; no extra member. *)
let gen_ecall_marshal_struct (tf: Ast.trusted_func) =
  let no_extra_member = "" in
  gen_marshal_struct tf.Ast.tf_fdecl no_extra_member

(* Marshaling structure of an untrusted function.  When the function
 * propagates errno, an `ocall_errno' member is added. *)
let gen_ocall_marshal_struct (uf: Ast.untrusted_func) =
  let errno_decl =
    match uf.Ast.uf_propagate_errno with
      true -> "\tint ocall_errno;\n"
    | false -> ""
  in
  gen_marshal_struct uf.Ast.uf_fdecl errno_decl
(* Check whether given parameter is `const' specified: it must be a
 * pointer parameter marked read-only whose type is not `Ast.Foreign'. *)
let is_const_ptr (pt: Ast.parameter_type) =
  let aty = Ast.get_param_atype pt in
  let is_foreign =
    match aty with
      Ast.Foreign _ -> true
    | _ -> false
  in
  match pt with
    Ast.PTVal _ -> false
  | Ast.PTPtr(_, pa) -> pa.Ast.pa_rdonly && not is_foreign
(* Generate parameter representation: the typed declarator, with a
 * leading `const ' when `is_const_ptr' holds for the parameter. *)
let gen_parm_str (p: Ast.pdecl) =
  let (pt, (declr : Ast.declarator)) = p in
  let typed_declr = get_typed_declr_str (Ast.get_param_atype pt) declr in
  let qualifier = if is_const_ptr pt then "const " else "" in
  qualifier ^ typed_declr
(* Generate parameter representation of return value: a pointer to the
 * return type named `retval', or "" for a void function. *)
let gen_parm_retval (rt: Ast.atype) =
  match rt with
    Ast.Void -> ""
  | _ -> String.concat "" [Ast.get_tystr rt; "* "; retval_name]
(* ---------------------------------------------------------------------- *)
(* `gen_ecall_table' is used to generate ECALL table with the following form:
   SGX_EXTERNC const struct {
     size_t nr_ecall; /* number of ECALLs */
     struct {
       void *ecall_addr;
       uint8_t is_priv;
     } ecall_table [nr_ecall];
   } g_ecall_table = {
     2, { {sgx_foo, 1}, {sgx_bar, 0} }
   };
   When there is no trusted function, only the count is initialized.
*)
let gen_ecall_table (tfs: Ast.trusted_func list) =
  let nr_ecall = List.length tfs in
  (* One initializer line: trusted bridge address and private flag. *)
  let entry_line (tf: Ast.trusted_func) =
    let bridge = mk_tbridge_name (get_tf_fname tf) in
    let priv_bit = if is_priv_ecall tf then 1 else 0 in
    sprintf "\t\t{(void*)(uintptr_t)%s, %d},\n" bridge priv_bit
  in
  let table_init =
    if nr_ecall = 0 then ""
    else "\t{\n" ^ String.concat "" (List.map entry_line tfs) ^ "\t}\n"
  in
  sprintf "SGX_EXTERNC const struct {\n\
           \tsize_t nr_ecall;\n\
           \tstruct {void* ecall_addr; uint8_t is_priv;} ecall_table[%d];\n\
           } %s = {\n\
           \t%d,\n\
           %s};\n" nr_ecall
    "g_ecall_table"
    nr_ecall
    table_init
(* `gen_entry_table' is used to generate Dynamic Entry Table with the form:
   SGX_EXTERNC const struct {
     /* number of OCALLs (number of ECALLs can be found in ECALL table) */
     size_t nr_ocall;
     /* entry_table[m][n] = 1 iff. ECALL n is allowed in the OCALL m. */
     uint8_t entry_table[NR_OCALL][NR_ECALL];
   } g_dyn_entry_table = {
     3, {{0, 0}, {0, 1}, {1, 0}}
   };
   When NR_ECALL is 0, or NR_OCALL is 0, there is neither an
   `entry_table' field nor an initializer for it.
*)
let gen_entry_table (ec: enclave_content) =
  let ecall_names = get_trusted_func_names ec in
  let nr_ocall = List.length ec.ufunc_decls in
  let nr_ecall = List.length ecall_names in
  (* One row per OCALL: a 1/0 flag for every ECALL, in ECALL order. *)
  let row_of (uf: Ast.untrusted_func) =
    let allowed = uf.Ast.uf_allow_list in
    let flag name = if List.mem name allowed then "1, " else "0, " in
    "\t\t{" ^ String.concat "" (List.map flag ecall_names) ^ "},\n"
  in
  (* Generate dynamic entry table iff. both table sizes are > 0 *)
  let (table_field, table_init) =
    if nr_ecall > 0 && nr_ocall > 0 then
      (sprintf "\tuint8_t entry_table[%d][%d];\n" nr_ocall nr_ecall,
       "\t{\n" ^ String.concat "" (List.map row_of ec.ufunc_decls) ^ "\t}\n")
    else
      ("", "")
  in
  sprintf "SGX_EXTERNC const struct {\n\
           \tsize_t nr_ocall;\n%s\
           } %s = {\n\
           \t%d,\n\
           %s};\n" table_field
    "g_dyn_entry_table"
    nr_ocall
    table_init
(* ---------------------------------------------------------------------- *)
(* Generate the function prototype for trusted proxy in COM style.
 * For example, un-trusted functions
 * int foo(double d);
 * void bar(float f);
 *
 * will have a trusted proxy like below:
 * sgx_status_t foo(int* retval, double d);
 * sgx_status_t bar(float f);
 *
 * A function without parameters gets only the `retval' pointer, or an
 * empty parameter list when it also returns void.
 *)
let gen_tproxy_proto (fd: Ast.func_decl) =
  let parms = List.map gen_parm_str fd.Ast.plist in
  let retval_parm = gen_parm_retval fd.Ast.rtype in
  let args =
    match parms with
      [] -> [retval_parm]
    | _ :: _ ->
        if fd.Ast.rtype = Ast.Void then parms else retval_parm :: parms
  in
  sprintf "sgx_status_t SGX_CDECL %s(%s)" fd.Ast.fname (String.concat ", " args)
(* Generate the function prototype for untrusted proxy in COM style.
 * For example, trusted functions
 * int foo(double d);
 * void bar(float f);
 *
 * will have an untrusted proxy like below:
 * sgx_status_t foo(sgx_enclave_id_t eid, int* retval, double d);
 * sgx_status_t bar(sgx_enclave_id_t eid, float f);
 *
 * When `g_use_prefix' is true, the untrusted proxy name is prefixed
 * with the `prefix' parameter and an underscore.
 *
 *)
let gen_uproxy_com_proto (fd: Ast.func_decl) (prefix: string) =
  let eid_parm = "sgx_enclave_id_t " ^ eid_name in
  (* The retval pointer is only present for non-void functions. *)
  let retval_parms =
    if fd.Ast.rtype = Ast.Void then []
    else [gen_parm_retval fd.Ast.rtype]
  in
  let args = eid_parm :: (retval_parms @ List.map gen_parm_str fd.Ast.plist) in
  let fname =
    match !g_use_prefix with
      true -> sprintf "%s_%s" prefix fd.Ast.fname
    | false -> fd.Ast.fname
  in
  String.concat "" ["sgx_status_t "; fname; "("; String.concat ", " args; ")"]
(* Type string of the return type of `fd'. *)
let get_ret_tystr (fd: Ast.func_decl) =
  fd.Ast.rtype |> Ast.get_tystr

(* Parameter list of `fd' rendered as comma-separated C parameters;
 * the empty string when there are no parameters. *)
let get_plist_str (fd: Ast.func_decl) =
  fd.Ast.plist
  |> List.map gen_parm_str
  |> String.concat ", "
(* Generate the function prototype as is: return type, name and
 * parenthesized parameter list, without a trailing semicolon. *)
let gen_func_proto (fd: Ast.func_decl) =
  sprintf "%s %s(%s)" (get_ret_tystr fd) fd.Ast.fname (get_plist_str fd)
(* Generate prototypes for untrusted function, expressed through the
 * SGX_UBRIDGE macro.  SGX_DLLIMPORT is prepended when the function
 * attribute `fa_dllimport' is set. *)
let gen_ufunc_proto (uf: Ast.untrusted_func) =
  let fattr = uf.Ast.uf_fattr in
  let fdecl = uf.Ast.uf_fdecl in
  let dllimport = if fattr.Ast.fa_dllimport then "SGX_DLLIMPORT " else "" in
  let cconv_str = "SGX_" ^ Ast.get_call_conv_str fattr.Ast.fa_convention in
  sprintf "%s%s SGX_UBRIDGE(%s, %s, (%s))"
    dllimport
    (get_ret_tystr fdecl)
    cconv_str
    fdecl.Ast.fname
    (get_plist_str fdecl)
(* The preamble of the untrusted header: include guard, the common
 * system includes, the caller-supplied include lines and the common
 * macros. *)
let gen_uheader_preemble (guard: string) (inclist: string)=
  let guard_open = sprintf "#ifndef %s\n#define %s\n\n" guard guard in
  let system_includes =
    "#include <stdint.h>\n\
     #include <wchar.h>\n\
     #include <stddef.h>\n\
     #include <string.h>\n\
     #include \"sgx_edger8r.h\" /* for sgx_satus_t etc. */\n"
  in
  String.concat ""
    [guard_open; system_includes; "\n"; inclist; "\n"; common_macros]
(* Write the marshaling structures of all trusted functions, then of
 * all untrusted functions, to `out_chan'; each one is followed by a
 * newline.  All structures are rendered before anything is written. *)
let ms_writer out_chan ec =
  let ecall_structs = List.map gen_ecall_marshal_struct ec.tfunc_decls in
  let ocall_structs = List.map gen_ocall_marshal_struct ec.ufunc_decls in
  let emit_line text = output_string out_chan (text ^ "\n") in
  List.iter emit_line ecall_structs;
  List.iter emit_line ocall_structs
(* Generate untrusted header for enclave.
 *
 * The file name comes from `get_uheader_name'.  The header contains,
 * in order: the preamble (include guard, includes, common macros), the
 * composite type definitions, the prototypes of the untrusted
 * functions, the prototypes of the untrusted proxies and the footer.
 *
 * All text is rendered before the file is opened.  If a write fails,
 * the channel is closed before the exception is re-raised so that the
 * file descriptor is not leaked.
 *)
let gen_untrusted_header (ec: enclave_content) =
  let header_fname = get_uheader_name ec.file_shortnm in
  let guard_macro = sprintf "%s_U_H__" (String.uppercase ec.enclave_name) in
  let preemble_code =
    let include_list = gen_include_list (ec.include_list @ !untrusted_headers) in
    gen_uheader_preemble guard_macro include_list
  in
  let comp_def_list = List.map gen_comp_def ec.comp_defs in
  let func_proto_ufunc = List.map gen_ufunc_proto ec.ufunc_decls in
  let uproxy_com_proto =
    List.map (fun (tf: Ast.trusted_func) ->
        gen_uproxy_com_proto tf.Ast.tf_fdecl ec.enclave_name)
      ec.tfunc_decls
  in
  let out_chan = open_out header_fname in
  let write_header () =
    output_string out_chan (preemble_code ^ "\n");
    List.iter (fun s -> output_string out_chan (s ^ "\n")) comp_def_list;
    List.iter (fun s -> output_string out_chan (s ^ ";\n")) func_proto_ufunc;
    output_string out_chan "\n";
    List.iter (fun s -> output_string out_chan (s ^ ";\n")) uproxy_com_proto;
    output_string out_chan header_footer
  in
  begin
    try write_header ()
    with ex ->
      (* Release the channel, then let the original error propagate. *)
      close_out_noerr out_chan;
      raise ex
  end;
  close_out out_chan
(* It generates the preamble for trusted header file: include guard,
 * the common system includes, the caller-supplied include lines and
 * the common macros. *)
let gen_theader_preemble (guard: string) (inclist: string) =
  let guard_open = sprintf "#ifndef %s\n#define %s\n\n" guard guard in
  let system_includes =
    "#include <stdint.h>\n\
     #include <wchar.h>\n\
     #include <stddef.h>\n\
     #include \"sgx_edger8r.h\" /* for sgx_ocall etc. */\n\n"
  in
  String.concat "" [guard_open; system_includes; inclist; "\n"; common_macros]
(* Generate function prototype for functions used by `sizefunc' attribute.
 *
 * Every pointer parameter of every trusted and untrusted function is
 * scanned.  Each distinct sizefunc name is recorded together with the
 * parameter type and array dimensions it is applied to; using the same
 * name with a different type is reported through `failwithf'.  One
 * prototype per recorded name is then written to `out_chan', in the
 * iteration order of the hash table, followed by an empty line. *)
let gen_sizefunc_proto out_chan (ec: enclave_content) =
  let all_fds =
    tf_list_to_fd_list ec.tfunc_decls @ uf_list_to_fd_list ec.ufunc_decls
  in
  let dict = Hashtbl.create 4 in
  let register (fname: string) (ty: Ast.atype * int list) =
    if Hashtbl.mem dict fname then begin
      if Hashtbl.find dict fname <> ty then
        failwithf "`%s' requires different parameter types" fname
    end
    else Hashtbl.add dict fname ty
  in
  let scan_param ((pt, declr): Ast.pdecl) =
    match pt with
      Ast.PTVal _ -> ()
    | Ast.PTPtr(aty, pattr) ->
        (match pattr.Ast.pa_size.Ast.ps_sizefunc with
           Some s -> register s (aty, declr.Ast.array_dims)
         | None -> ())
  in
  let emit_proto fname (aty, dims) =
    let val_declr = { Ast.identifier = "val"; Ast.array_dims = dims; } in
    output_string out_chan
      (sprintf "size_t %s(const %s);\n" fname (get_typed_declr_str aty val_declr))
  in
  List.iter (fun (fd: Ast.func_decl) -> List.iter scan_param fd.Ast.plist) all_fds;
  Hashtbl.iter emit_proto dict;
  output_string out_chan "\n"
(* Generate trusted header for enclave.
 *
 * The file name comes from `get_theader_name'.  The header contains,
 * in order: the preamble (include guard, includes, common macros), the
 * composite type definitions, the prototypes required by `sizefunc'
 * attributes, the prototypes of the trusted functions, the prototypes
 * of the trusted proxies and the footer.
 *
 * `gen_sizefunc_proto' can fail after the file has been opened (it
 * reports conflicting sizefunc uses through `failwithf'), and any write
 * can fail as well.  In both cases the channel is closed before the
 * exception is re-raised so that the file descriptor is not leaked.
 *)
let gen_trusted_header (ec: enclave_content) =
  let header_fname = get_theader_name ec.file_shortnm in
  let guard_macro = sprintf "%s_T_H__" (String.uppercase ec.enclave_name) in
  let guard_code =
    let include_list = gen_include_list (ec.include_list @ !trusted_headers) in
    gen_theader_preemble guard_macro include_list in
  let comp_def_list = List.map gen_comp_def ec.comp_defs in
  let func_proto_list = List.map gen_func_proto (tf_list_to_fd_list ec.tfunc_decls) in
  let func_tproxy_list = List.map gen_tproxy_proto (uf_list_to_fd_list ec.ufunc_decls) in
  let out_chan = open_out header_fname in
  let write_header () =
    output_string out_chan (guard_code ^ "\n");
    List.iter (fun s -> output_string out_chan (s ^ "\n")) comp_def_list;
    gen_sizefunc_proto out_chan ec;
    List.iter (fun s -> output_string out_chan (s ^ ";\n")) func_proto_list;
    output_string out_chan "\n";
    List.iter (fun s -> output_string out_chan (s ^ ";\n")) func_tproxy_list;
    output_string out_chan header_footer
  in
  begin
    try write_header ()
    with ex ->
      (* Release the channel, then let the original error propagate. *)
      close_out_noerr out_chan;
      raise ex
  end;
  close_out out_chan
(* Expression used to pass a parameter in a function invocation: the
 * accessor of the marshaling structure member.  For an array declarator
 * with more than one dimension the accessor is preceded by a cast to a
 * pointer-to-array type built from all dimensions but the first. *)
let mk_parm_name_raw (pt: Ast.parameter_type) (declr: Ast.declarator) =
  let accessor = mk_parm_accessor declr.Ast.identifier in
  let is_multi_dim = Ast.is_array declr && List.length declr.Ast.array_dims > 1 in
  match is_multi_dim, declr.Ast.array_dims with
    true, _ :: inner_dims ->
      let cast = sprintf "(%s (*)%s)" (get_param_tystr pt) (get_array_dims inner_dims) in
      cast ^ accessor
  | _ -> accessor
(* We passed foreign array `foo_array_t foo' as `&foo[0]', thus we
 * need to get back `foo' by dereferencing the array pointer, where
 * array_ptr = &foo[0]
 *
 * `f' produces the plain argument expression.  For a foreign array the
 * expression is wrapped so that it is dereferenced only when it is not
 * NULL; any other parameter is returned unchanged.
 *)
let add_foreign_array_ptrref
    (f: Ast.parameter_type -> Ast.declarator -> string)
    (pt: Ast.parameter_type)
    (declr: Ast.declarator) =
  let arg = f pt declr in
  if not (is_foreign_array pt) then arg
  else sprintf "(%s != NULL) ? (*%s) : NULL" arg arg

(* Argument expression used by the untrusted bridge. *)
let mk_parm_name_ubridge (pt: Ast.parameter_type) (declr: Ast.declarator) =
  add_foreign_array_ptrref mk_parm_name_raw pt declr
(* Argument expression for a parameter: pointer parameters that carry a
 * direction attribute use the `_in_' variable of the parameter; value
 * parameters and pointers without direction use the raw accessor. *)
let mk_parm_name_ext (pt: Ast.parameter_type) (declr: Ast.declarator) =
  let has_direction =
    match pt with
      Ast.PTVal _ -> false
    | Ast.PTPtr (_, attr) -> attr.Ast.pa_direction <> Ast.PtrNoDirection
  in
  if has_direction then mk_in_var declr.Ast.identifier
  else mk_parm_name_raw pt declr
(* Generate the C statement that calls `fd' with one argument per
 * parameter.  `mk_parm_name' supplies the expression for each argument;
 * arguments for which `is_const_ptr' holds are cast to the const
 * parameter type. *)
let gen_func_invoking (fd: Ast.func_decl)
    (mk_parm_name: Ast.parameter_type -> Ast.declarator -> string) =
  let arg_expr (pt, (declr : Ast.declarator)) =
    let parm_name = mk_parm_name pt declr in
    if is_const_ptr pt
    then sprintf "(const %s)%s" (get_param_tystr pt) parm_name
    else parm_name
  in
  let args = List.map arg_expr fd.Ast.plist in
  sprintf "%s(%s);" fd.Ast.fname (String.concat ", " args)
(* Generate untrusted bridge code for a given untrusted function.
 *
 * A naked function that does not propagate errno takes no marshaling
 * structure: the bridge rejects a non-NULL pointer and calls the
 * function directly.  Otherwise the bridge casts the pointer to the
 * marshaling structure, calls the function with the structure members
 * (storing the return value in the structure when there is one) and,
 * if requested, saves errno into the structure. *)
let gen_func_ubridge (file_shortnm: string) (ufunc: Ast.untrusted_func) =
  let fd = ufunc.Ast.uf_fdecl in
  let propagate_errno = ufunc.Ast.uf_propagate_errno in
  let func_open = sprintf "%s\n{\n" (mk_ubridge_proto file_shortnm fd.Ast.fname) in
  let func_close = "\treturn SGX_SUCCESS;\n}\n" in
  let call_with_pms =
    let invoke_func = gen_func_invoking fd mk_parm_name_ubridge in
    match fd.Ast.rtype with
      Ast.Void -> invoke_func
    | _ -> sprintf "%s = %s" (mk_parm_accessor retval_name) invoke_func
  in
  if is_naked_func fd && not propagate_errno then
    let check_pms =
      sprintf "if (%s != NULL) return SGX_ERROR_INVALID_PARAMETER;" ms_ptr_name
    in
    sprintf "%s\t%s\n\t%s\n%s" func_open check_pms call_with_pms func_close
  else
    let ms_struct_name = mk_ms_struct_name fd.Ast.fname in
    let declare_ms_ptr =
      sprintf "%s* %s = SGX_CAST(%s*, %s);"
        ms_struct_name ms_struct_val ms_struct_name ms_ptr_name
    in
    let set_errno =
      if propagate_errno then "\tms->ocall_errno = errno;" else ""
    in
    sprintf "%s\t%s\n\t%s\n%s\n%s"
      func_open declare_ms_ptr call_with_pms set_errno func_close
(* Generate the C statement that stores a parameter into its member of
 * the marshaling structure.  `isptr' selects the arrow or the dot
 * accessor on the structure variable.
 *
 * - value parameters are assigned directly;
 * - foreign arrays store the address of their first element, cast to
 *   a pointer to the foreign type;
 * - read-only pointers are assigned through a cast to their own type;
 * - array declarators are passed by address, cast to a pointer type. *)
let fill_ms_field (isptr: bool) (pd: Ast.pdecl) =
  let (pt, declr) = pd in
  let param_name = declr.Ast.identifier in
  let accessor = if isptr then "->" else "." in
  let ms_member_name = mk_ms_member_name param_name in
  (* Plain assignment, optionally through a cast to `aty'. *)
  let assign (use_cast: bool) (aty: Ast.atype) =
    let cast_str = if use_cast then sprintf "(%s)" (Ast.get_tystr aty) else "" in
    sprintf "%s%s%s = %s%s;" ms_struct_val accessor ms_member_name cast_str param_name
  in
  match declr.Ast.array_dims, pt with
    [], Ast.PTVal aty -> assign false aty
  | [], Ast.PTPtr(aty, pattr) ->
      if pattr.Ast.pa_isary then
        sprintf "%s%s%s = (%s *)&%s[0];"
          ms_struct_val accessor ms_member_name (Ast.get_tystr aty) param_name
      else
        assign pattr.Ast.pa_rdonly aty
  | _ :: _, _ ->
      (* Arrays are passed by address. *)
      let tystr = Ast.get_tystr (Ast.Ptr (Ast.get_param_atype pt)) in
      sprintf "%s%s%s = (%s)%s;" ms_struct_val accessor ms_member_name tystr param_name
(* Generate untrusted proxy code for a given trusted function.
 * `idx' is the index passed to sgx_ecall.
 *
 * Normal case: declare the marshaling structure, fill one member per
 * parameter, do the ECALL with the structure and, for a non-void
 * function, copy the return value out on success.
 *
 * Rare case: the trusted function has neither parameters nor return
 * value.  No marshaling structure is required - NULL is passed in.
 *)
let gen_func_uproxy (fd: Ast.func_decl) (idx: int) (ec: enclave_content) =
  let func_open =
    gen_uproxy_com_proto fd ec.enclave_name ^
    "\n{\n\tsgx_status_t status;\n"
  in
  let func_close = "\treturn status;\n}\n" in
  let ocall_table_ptr = sprintf "&%s" (mk_ocall_table_name ec.enclave_name) in
  if is_naked_func fd then
    let ecall_null =
      sprintf "status = sgx_ecall(%s, %d, %s, NULL);"
        eid_name idx ocall_table_ptr
    in
    sprintf "%s\t%s\n%s" func_open ecall_null func_close
  else
    let declare_ms_expr =
      sprintf "%s %s;" (mk_ms_struct_name fd.Ast.fname) ms_struct_val
    in
    let fill_fields = List.map (fill_ms_field false) fd.Ast.plist in
    let ecall_with_ms =
      sprintf "status = sgx_ecall(%s, %d, %s, &%s);"
        eid_name idx ocall_table_ptr ms_struct_val
    in
    let update_retval =
      if fd.Ast.rtype <> Ast.Void then
        [sprintf "if (status == SGX_SUCCESS && %s) *%s = %s.%s;"
           retval_name retval_name ms_struct_val ms_retval_name]
      else []
    in
    let statements = declare_ms_expr :: (fill_fields @ (ecall_with_ms :: update_retval)) in
    let body = String.concat "" (List.map (fun s -> "\t" ^ s ^ "\n") statements) in
    func_open ^ body ^ func_close
(* Generate an expression to check the pointers: a tab-indented
 * CHECK_UNIQUE_POINTER invocation on `name' with length `lenvar'. *)
let mk_check_ptr (name: string) (lenvar: string) =
  String.concat "" ["\t"; "CHECK_UNIQUE_POINTER"; "("; name; ", "; lenvar; ");\n"]
(* Pointer to marshaling structure should never be NULL.  Emits a
 * CHECK_REF_POINTER invocation on the structure pointer with the size
 * of the marshaling structure of `fname'. *)
let mk_check_pms (fname: string) =
  let struct_size = sprintf "sizeof(%s)" (mk_ms_struct_name fname) in
  sprintf "\t%s(%s, %s);\n" "CHECK_REF_POINTER" ms_ptr_name struct_size
(* Generate code to get the size of the pointer: a declaration of the
 * `_len_' variable of `name' initialized with the buffer length.
 * `get_parm' maps a parameter name to the C expression that reads it.
 *
 * The length is, in order of precedence:
 * - the size of the type, for a foreign array;
 * - the string length plus terminator, for string / wide string
 *   pointers (0 when the pointer is NULL);
 * - otherwise the element size, multiplied by the `count' attribute
 *   when there is one.  The element size is the `size' attribute, else
 *   the result of the `sizefunc' function (0 for a NULL pointer), else
 *   the size of the pointee.
 *
 * Note, during the parsing stage, we already eliminated the case that
 * user specified both 'size' and 'sizefunc' attribute.
 *)
let gen_ptr_size (ty: Ast.atype) (pattr: Ast.ptr_attr) (name: string) (get_parm: string -> string) =
  let parm_name = get_parm name in
  (* An attribute value is either a parameter reference or a number. *)
  let attr_value_str v =
    match v with
      Ast.AString s -> get_parm s
    | Ast.ANumber n -> sprintf "%d" n
  in
  let element_size (sattr: Ast.ptr_size) =
    match sattr.Ast.ps_size, sattr.Ast.ps_sizefunc with
      Some a, _ -> attr_value_str a
    | None, Some fn -> sprintf "((%s) ? %s(%s) : 0)" parm_name fn parm_name
    | None, None -> sprintf "sizeof(*%s)" parm_name
  in
  let sized_length (sattr: Ast.ptr_size) =
    let size_str = element_size sattr in
    match sattr.Ast.ps_count with
      None -> size_str
    | Some cnt ->
        let count_str = attr_value_str cnt in
        sprintf "%s * %s" count_str size_str
  in
  let length_expr =
    if pattr.Ast.pa_isary then
      sprintf "sizeof(%s)" (Ast.get_tystr ty)
    else if pattr.Ast.pa_isstr then
      sprintf "%s ? strlen(%s) + 1 : 0" parm_name parm_name
    else if pattr.Ast.pa_iswstr then
      sprintf "%s ? (wcslen(%s) + 1) * sizeof(wchar_t) : 0" parm_name parm_name
    else
      sized_length pattr.Ast.pa_size
  in
  sprintf "size_t %s = %s;\n" (mk_len_var name) length_expr
(* Find the data type of a parameter: the type string of the first
 * parameter in `plist' whose identifier is `name'.  A missing parameter
 * is reported through `failwithf'. *)
let find_param_type (name: string) (plist: Ast.pdecl list) =
  let has_name ((_, declr): Ast.pdecl) = declr.Ast.identifier = name in
  try
    let (pt, _) = List.find has_name plist in
    get_param_tystr pt
  with
    Not_found -> failwithf "parameter `%s' not found." name
(* Generate code to check the length of buffers.
 *
 * For every pointer parameter that has a `count' attribute, a C `if'
 * is emitted that compares the count against SIZE_MAX divided by the
 * element size and jumps to the `err' label with
 * SGX_ERROR_INVALID_PARAMETER when the product would not fit.
 * Parameter references are read through their `_tmp_' variables.
 * Parameters without a count contribute nothing. *)
let gen_check_tbridge_length_overflow (plist: Ast.pdecl list) =
  let element_size (attr: Ast.ptr_attr) tmp_ptr_name =
    match attr.Ast.pa_size.Ast.ps_size, attr.Ast.pa_size.Ast.ps_sizefunc with
      Some (Ast.AString s), _ -> mk_tmp_var s
    | Some (Ast.ANumber n), _ -> sprintf "%d" n
    | None, Some fn -> sprintf "((%s) ? %s(%s) : 0)" tmp_ptr_name fn tmp_ptr_name
    | None, None -> sprintf "sizeof(*%s)" tmp_ptr_name
  in
  let overflow_check cnt size_str =
    let if_statement =
      match cnt with
        Ast.AString s ->
          sprintf "\tif ((size_t)%s > (SIZE_MAX / %s)) {\n" (mk_tmp_var s) size_str
      | Ast.ANumber n ->
          sprintf "\tif (%d > (SIZE_MAX / %s)) {\n" n size_str
    in
    sprintf "%s\t\tstatus = SGX_ERROR_INVALID_PARAMETER;\n\t\tgoto err;\n\t}" if_statement
  in
  let check_param (pty, (declr : Ast.declarator)) =
    match pty with
      Ast.PTVal _ -> ""
    | Ast.PTPtr(_, attr) ->
        (match attr.Ast.pa_size.Ast.ps_count with
           None -> ""
         | Some cnt ->
             let size_str = element_size attr (mk_tmp_var declr.Ast.identifier) in
             sprintf "%s\n\n" (overflow_check cnt size_str))
  in
  String.concat "" (List.map check_param plist)
(* Generate code to check all function parameters which are pointers.
 *
 * Array parameters are first converted to pointer form.  For every
 * pointer parameter whose `pa_chkptr' attribute is set, a pointer check
 * on its `_tmp_' variable with its `_len_' variable is emitted; value
 * parameters and unchecked pointers contribute nothing.
 *
 * The `pa_chkptr' flag used to be tested twice, which left an
 * unreachable branch behind; it is now tested once. *)
let gen_check_tbridge_ptr_parms (plist: Ast.pdecl list) =
  let gen_check_ptr (pattr: Ast.ptr_attr) (declr: Ast.declarator) =
    if pattr.Ast.pa_chkptr then
      let name = declr.Ast.identifier in
      mk_check_ptr (mk_tmp_var name) (mk_len_var name)
    else ""
  in
  let new_param_list = List.map conv_array_to_ptr plist in
  List.fold_left
    (fun acc (pty, declr) ->
       match pty with
         Ast.PTVal _ -> acc
       | Ast.PTPtr(_, attr) -> acc ^ gen_check_ptr attr declr)
    "" new_param_list
(* Return the expression used as destination for memcpy() and free().
 * A readonly pointer is prefixed with a cast to 'void*'; any other
 * pointer name is returned unchanged. *)
let mk_in_ptr_dst_name (rdonly: bool) (ptr_name: string) =
  match rdonly with
  | true -> "(void*)" ^ ptr_name
  | false -> ptr_name
(* Generate the code that prepares directed pointer parameters before
 * the trusted function is called.
 *
 * For an `in' or `in,out' pointer the generated C allocates a buffer,
 * copies the caller data into it, terminates string and wide-string
 * buffers, and, when a size function is attached, re-runs it on the
 * copy and compares the result with the length variable.  For an `out'
 * pointer the buffer is allocated and zero-filled.  Any other
 * direction and all value parameters generate nothing.
 *)
let gen_parm_ptr_direction_pre (plist: Ast.pdecl list) =
  let indent = "\t" in
  (* Prefix every line with the indent and terminate it with a newline. *)
  let render lines =
    String.concat "" (List.map (fun line -> indent ^ line ^ "\n") lines)
  in
  let clone_in_ptr (ty: Ast.atype) (attr: Ast.ptr_attr) (declr: Ast.declarator) =
    let name = declr.Ast.identifier in
    let star = if Ast.is_array declr || attr.Ast.pa_isary then "*" else "" in
    let cast_type = Ast.get_tystr ty ^ star in
    let in_ptr = mk_in_var name in
    let len = mk_len_var name in
    let in_ptr_dst = mk_in_ptr_dst_name attr.Ast.pa_rdonly in_ptr in
    let tmp_ptr = mk_tmp_var name in
    (* Shared tail of the allocation-failure branch. *)
    let oom_guard =
      [ "\t\tstatus = SGX_ERROR_OUT_OF_MEMORY;";
        "\t\tgoto err;";
        "\t}\n" ]
    in
    match attr.Ast.pa_direction with
    | Ast.PtrIn | Ast.PtrInOut ->
      let copy_in =
        render
          ([ sprintf "if (%s != NULL) {" tmp_ptr;
             sprintf "\t%s = (%s)malloc(%s);" in_ptr cast_type len;
             sprintf "\tif (%s == NULL) {" in_ptr ]
           @ oom_guard
           @ [ sprintf "\tmemcpy(%s, %s, %s);" in_ptr_dst tmp_ptr len ])
      in
      (* Force termination of copied strings. *)
      let terminator =
        if attr.Ast.pa_isstr then
          sprintf "\t\t%s[%s - 1] = '\\0';\n" in_ptr len
        else if attr.Ast.pa_iswstr then
          sprintf "\t\t%s[(%s - sizeof(wchar_t))/sizeof(wchar_t)] = (wchar_t)0;\n" in_ptr len
        else ""
      in
      let recheck =
        match attr.Ast.pa_size.Ast.ps_sizefunc with
        | None -> ""
        | Some fn ->
          sprintf "\n\t\t/* check whether the pointer is modified. */\n\
                   \t\tif (%s(%s) != %s) {\n\
                   \t\t\tstatus = SGX_ERROR_INVALID_PARAMETER;\n\
                   \t\t\tgoto err;\n\
                   \t\t}\n" fn in_ptr len
      in
      copy_in ^ terminator ^ recheck ^ "\t}\n"
    | Ast.PtrOut ->
      render
        ([ sprintf "if (%s != NULL) {" tmp_ptr;
           sprintf "\tif ((%s = (%s)malloc(%s)) == NULL) {" in_ptr cast_type len ]
         @ oom_guard
         @ [ sprintf "\tmemset((void*)%s, 0, %s);" in_ptr len;
             "}" ])
    | _ -> ""
  in
  let per_param ((pty, declr): Ast.pdecl) =
    match pty with
    | Ast.PTVal _ -> ""
    | Ast.PTPtr (ty, attr) -> clone_in_ptr ty attr declr
  in
  String.concat "" (List.map per_param plist)
(* Generate the code that runs after the trusted function returns:
 * an `in' buffer is freed, while an `in,out' or `out' buffer is copied
 * back to the saved caller pointer and then freed.  Other directions
 * and value parameters generate nothing.
 *)
let gen_parm_ptr_direction_post (plist: Ast.pdecl list) =
  let copy_and_free (attr: Ast.ptr_attr) (declr: Ast.declarator) =
    let name = declr.Ast.identifier in
    let in_ptr = mk_in_var name in
    match attr.Ast.pa_direction with
    | Ast.PtrIn ->
      let freed = mk_in_ptr_dst_name attr.Ast.pa_rdonly in_ptr in
      sprintf "\tif (%s) free(%s);\n" in_ptr freed
    | Ast.PtrInOut | Ast.PtrOut ->
      let caller_ptr = mk_tmp_var name in
      let len = mk_len_var name in
      String.concat ""
        [ sprintf "\tif (%s) {\n" in_ptr;
          sprintf "\t\tmemcpy(%s, %s, %s);\n" caller_ptr in_ptr len;
          sprintf "\t\tfree(%s);\n" in_ptr;
          "\t}\n" ]
    | _ -> ""
  in
  let per_param ((pty, declr): Ast.pdecl) =
    match pty with
    | Ast.PTVal _ -> ""
    | Ast.PTPtr (_, attr) -> copy_and_free attr declr
  in
  String.concat "" (List.map per_param plist)
(* Return the `err:' goto label when at least one pointer parameter has
 * a direction other than PtrNoDirection; otherwise return the empty
 * string. *)
let gen_err_mark (plist: Ast.pdecl list) =
  let needs_mark ((pt, _): Ast.pdecl) =
    match pt with
    | Ast.PTVal _ -> false
    | Ast.PTPtr (_, attr) -> attr.Ast.pa_direction <> Ast.PtrNoDirection
  in
  if List.exists needs_mark plist then "err:" else ""
(* Records the names of parameters that are used as the value of a
 * size/count attribute and for which a temporary has been generated. *)
let param_cache : (string, bool) Hashtbl.t = Hashtbl.create 1

(* Tell whether parameter `key' has already been recorded. *)
let is_in_param_cache (key: string) = Hashtbl.mem param_cache key
(* Generate the declarations of the temporaries that hold the values of
 * the size and count attributes of one pointer parameter.  Numeric
 * attributes need no temporary.  A named attribute gets one temporary,
 * emitted only the first time the name is seen; the name is then
 * recorded in param_cache.  The size attribute is handled before the
 * count attribute. *)
let gen_tmp_size (pattr: Ast.ptr_attr) (plist: Ast.pdecl list) =
  let declare_once (v: Ast.attr_value option) =
    match v with
    | None | Some (Ast.ANumber _) -> ""
    | Some (Ast.AString s) ->
      if is_in_param_cache s then ""
      else begin
        let param_tystr = find_param_type s plist in
        let tmp_var = mk_tmp_var s in
        let parm_str = mk_parm_accessor s in
        Hashtbl.add param_cache s true;
        sprintf "\t%s %s = %s;\n" param_tystr tmp_var parm_str
      end
  in
  let size_decl = declare_once pattr.Ast.pa_size.Ast.ps_size in
  let count_decl = declare_once pattr.Ast.pa_size.Ast.ps_count in
  size_decl ^ count_decl
(* True when the parameter is passed as a pointer (PTPtr). *)
let is_ptr : Ast.parameter_type -> bool = function
  | Ast.PTVal _ -> false
  | Ast.PTPtr _ -> true
(* True when the type itself is a pointer type (Ast.Ptr). *)
let is_ptr_type : Ast.atype -> bool = function
  | Ast.Ptr _ -> true
  | _ -> false
(* True when the parameter is a pointer whose direction attribute is
 * anything other than PtrNoDirection. *)
let ptr_has_direction : Ast.parameter_type -> bool = function
  | Ast.PTVal _ -> false
  | Ast.PTPtr (_, attr) -> attr.Ast.pa_direction <> Ast.PtrNoDirection
(* Choose the name under which a parameter is handed to the trusted
 * function inside the bridge.  A parameter recorded in param_cache, or
 * a pointer that is not a foreign array, is referred to through a
 * bridge-local variable: the in-copy when the pointer has a direction,
 * the temporary otherwise.  Everything else falls back to
 * mk_parm_name_ext. *)
let tbridge_mk_parm_name_ext (pt: Ast.parameter_type) (declr: Ast.declarator) =
  let id = declr.Ast.identifier in
  let uses_local =
    is_in_param_cache id || (is_ptr pt && not (is_foreign_array pt)) in
  if not uses_local then mk_parm_name_ext pt declr
  else if ptr_has_direction pt then mk_in_var id
  else mk_tmp_var id

(* Same naming rule, passed through add_foreign_array_ptrref. *)
let mk_parm_name_tbridge (pt: Ast.parameter_type) (declr: Ast.declarator) =
  add_foreign_array_ptrref tbridge_mk_parm_name_ext pt declr
(* Generate the local variable declarations of a trusted bridge.
 *
 * The status variable always comes first.  Every pointer parameter
 * then gets a saved copy of the pointer, optionally its length
 * variable (only when the pointer is checked), and a NULL-initialised
 * in-copy pointer when it has a direction.  Foreign arrays get the
 * same three variables with pointer-to-array types, but only when they
 * have a direction.
 *
 * param_cache is cleared before the declarations are generated and is
 * filled again, through gen_tmp_size, while they are generated.
 *)
let gen_tbridge_local_vars (plist: Ast.pdecl list) =
  let pointer_locals (ty: Ast.atype) (attr: Ast.ptr_attr) (name: string) =
    (* Save a copy of the pointer in case it is modified in the
     * marshaling structure. *)
    let saved_ptr =
      sprintf "\t%s %s = %s;\n" (Ast.get_tystr ty) (mk_tmp_var name) (mk_parm_accessor name)
    in
    let length =
      if attr.Ast.pa_chkptr
      then gen_tmp_size attr plist ^ "\t" ^ gen_ptr_size ty attr name mk_tmp_var
      else ""
    in
    let copy_slot =
      if attr.Ast.pa_direction = Ast.PtrNoDirection then ""
      else sprintf "\t%s %s = NULL;\n" (Ast.get_tystr ty) (mk_in_var name)
    in
    saved_ptr ^ length ^ copy_slot
  in
  let foreign_array_locals (ty: Ast.atype) (attr: Ast.ptr_attr) (name: string) =
    let tystr = Ast.get_tystr ty in
    if attr.Ast.pa_direction = Ast.PtrNoDirection then ""
    else
      String.concat ""
        [ sprintf "\t%s* %s = %s;\n" tystr (mk_tmp_var name) (mk_parm_accessor name);
          sprintf "\tsize_t %s = sizeof(%s);\n" (mk_len_var name) tystr;
          sprintf "\t%s* %s = NULL;\n" tystr (mk_in_var name) ]
  in
  let locals_of ((pty, declr): Ast.pdecl) =
    match pty with
    | Ast.PTVal _ -> ""
    | Ast.PTPtr (ty, attr) ->
      let gen =
        if is_foreign_array pty then foreign_array_locals else pointer_locals in
      gen ty attr declr.Ast.identifier
  in
  let params = List.map conv_array_to_ptr plist in
  Hashtbl.clear param_cache;
  (* fold_left keeps the left-to-right order the cache relies on. *)
  List.fold_left (fun acc pd -> acc ^ locals_of pd)
    "\tsgx_status_t status = SGX_SUCCESS;\n" params
(* Generate the trusted bridge for the trusted function `fd'.
 *
 * The bridge is a static C function returning sgx_status_t that takes
 * the marshalling structure as an untyped pointer named ms_ptr_name.
 * `dummy_var' is a piece of text spliced into the body of bridges for
 * naked functions only.
 *
 * For a naked function (is_naked_func) the body consists of the local
 * variables, `dummy_var', a check that rejects a non-NULL marshalling
 * pointer, and the call itself.
 *
 * For every other function the body is laid out in this order:
 * typed marshalling pointer, local variables, length-overflow checks,
 * mk_check_pms, pointer checks, pre-call handling of directed
 * pointers, the call (assigned to the return-value field unless the
 * function returns void), the optional error label, post-call handling
 * of directed pointers, and the final return of status.
 *)
let gen_func_tbridge (fd: Ast.func_decl) (dummy_var: string) =
let func_open = sprintf "static sgx_status_t SGX_CDECL %s(void* %s)\n{\n"
(mk_tbridge_name fd.Ast.fname)
ms_ptr_name in
(* NOTE: keep this binding ahead of `invoke_func'.
 * gen_tbridge_local_vars clears and refills param_cache, which
 * tbridge_mk_parm_name_ext (reached through mk_parm_name_tbridge)
 * consults; presumably gen_func_invoking calls that namer for each
 * argument - confirm before reordering. *)
let local_vars = gen_tbridge_local_vars fd.Ast.plist in
let func_close = "\treturn status;\n}\n" in
let ms_struct_name = mk_ms_struct_name fd.Ast.fname in
(* Cast the untyped bridge argument to the marshalling struct type. *)
let declare_ms_ptr = sprintf "%s* %s = SGX_CAST(%s*, %s);"
ms_struct_name
ms_struct_val
ms_struct_name
ms_ptr_name in
let invoke_func = gen_func_invoking fd mk_parm_name_tbridge in
(* Store the result of the call in the return-value field. *)
let update_retval = sprintf "%s = %s"
(mk_parm_accessor retval_name)
invoke_func in
if is_naked_func fd then
(* A naked function must be called with a NULL marshalling pointer. *)
let check_pms =
sprintf "if (%s != NULL) return SGX_ERROR_INVALID_PARAMETER;" ms_ptr_name
in
sprintf "%s%s%s\t%s\n\t%s\n%s" func_open local_vars dummy_var check_pms invoke_func func_close
else
(* The eleven arguments below fill the format slots in order. *)
sprintf "%s\t%s\n%s\n%s%s%s\n%s\t%s\n%s\n%s\n%s"
func_open
declare_ms_ptr
local_vars
(gen_check_tbridge_length_overflow fd.Ast.plist)
(mk_check_pms fd.Ast.fname)
(gen_check_tbridge_ptr_parms fd.Ast.plist)
(gen_parm_ptr_direction_pre fd.Ast.plist)
(if fd.Ast.rtype <> Ast.Void then update_retval else invoke_func)
(gen_err_mark fd.Ast.plist)
(gen_parm_ptr_direction_post fd.Ast.plist)
func_close
(* Generate the trusted-proxy code that stores one OCALL parameter in
 * the marshalling structure.
 *
 * - A value parameter is delegated to fill_ms_field.
 * - An array of pointers only produces a `#pragma message' warning.
 * - An unchecked pointer (user_check) is stored with a plain cast.
 * - Any other pointer gets a slot carved out of the __tmp area when it
 *   is non-NULL and lies within the enclave; the slot is zero-filled
 *   for an `out' pointer and filled from the parameter otherwise.  A
 *   NULL pointer is passed on as NULL, and anything else makes the
 *   generated code free the area and return an error.
 *)
let tproxy_fill_ms_field (pd: Ast.pdecl) =
  let (pt, declr) = pd in
  let name = declr.Ast.identifier in
  let len_var = mk_len_var name in
  let ms_field = mk_parm_accessor name in
  match pt with
  | Ast.PTVal _ -> fill_ms_field true pd
  | Ast.PTPtr (ty, attr) ->
    let is_ary = Ast.is_array declr || attr.Ast.pa_isary in
    let tystr = get_param_tystr pt ^ (if is_ary then "*" else "") in
    if is_ary && is_ptr_type ty then
      sprintf "\n#pragma message(\"Pointer array `%s' in trusted proxy `\"__FUNCTION__ \"' is dangerous. No code generated.\")\n" name
    else
      let dst = mk_in_ptr_dst_name attr.Ast.pa_rdonly ms_field in
      if not attr.Ast.pa_chkptr then
        (* [user_check] specified *)
        sprintf "%s = SGX_CAST(%s, %s);" ms_field tystr name
      else
        (* The only line that depends on the direction. *)
        let init_buffer =
          match attr.Ast.pa_direction with
          | Ast.PtrOut -> sprintf "\tmemset(%s, 0, %s);" dst len_var
          | _ -> sprintf "\tmemcpy(%s, %s, %s);" dst name len_var
        in
        [ sprintf "if (%s != NULL && sgx_is_within_enclave(%s, %s)) {" name name len_var;
          sprintf "\t%s = (%s)__tmp;" ms_field tystr;
          sprintf "\t__tmp = (void *)((size_t)__tmp + %s);" len_var;
          init_buffer;
          sprintf "} else if (%s == NULL) {" name;
          sprintf "\t%s = NULL;" ms_field;
          "} else {";
          "\tsgx_ocfree();";
          "\treturn SGX_ERROR_INVALID_PARAMETER;";
          "}" ]
        |> List.map (fun line -> line ^ "\n\t")
        |> String.concat ""
(* Generate the local variables of a trusted proxy: the status variable
 * followed by one length variable, produced by gen_ptr_size, for every
 * checked pointer parameter.  Array parameters are converted with
 * conv_array_to_ptr first. *)
let gen_tproxy_local_vars (plist: Ast.pdecl list) =
  let length_decl ((pty, declr): Ast.pdecl) =
    match pty with
    | Ast.PTVal _ -> ""
    | Ast.PTPtr (ty, attr) ->
      if attr.Ast.pa_chkptr
      then "\t" ^ gen_ptr_size ty attr declr.Ast.identifier (fun x -> x)
      else ""
  in
  let params = List.map conv_array_to_ptr plist in
  List.fold_left (fun acc pd -> acc ^ length_decl pd)
    "sgx_status_t status = SGX_SUCCESS;\n" params
(* Generate the single sgx_ocalloc block of a trusted proxy.
 *
 * The generated code declares the marshalling pointer, the running
 * ocalloc_size and the __tmp cursor, adds the length of every checked
 * pointer parameter that is non-NULL and within the enclave, allocates
 * the whole area at once, and finally places the marshalling structure
 * at its start, leaving __tmp just behind it. *)
let gen_ocalloc_block (fname: string) (plist: Ast.pdecl list) =
  let ms_struct_name = mk_ms_struct_name fname in
  let declarations =
    sprintf "%s* %s = NULL;\n\tsize_t ocalloc_size = sizeof(%s);\n\tvoid *__tmp = NULL;\n\n"
      ms_struct_name ms_struct_val ms_struct_name
  in
  let size_contribution ((pty, declr): Ast.pdecl) =
    match pty with
    | Ast.PTVal _ -> ""
    | Ast.PTPtr (_, attr) ->
      if not attr.Ast.pa_chkptr then ""
      else
        let name = declr.Ast.identifier in
        sprintf "\tocalloc_size += (%s != NULL && sgx_is_within_enclave(%s, %s)) ? %s : 0;\n"
          name name (mk_len_var name) (mk_len_var name)
  in
  let allocation =
    String.concat ""
      [ "\n\t__tmp = sgx_ocalloc(ocalloc_size);\n";
        "\tif (__tmp == NULL) {\n";
        "\t\tsgx_ocfree();\n";
        "\t\treturn SGX_ERROR_UNEXPECTED;\n";
        "\t}\n";
        sprintf "\t%s = (%s*)__tmp;\n" ms_struct_val ms_struct_name;
        sprintf "\t__tmp = (void *)((size_t)__tmp + sizeof(%s));\n" ms_struct_name ]
  in
  let params = List.map conv_array_to_ptr plist in
  let sizes = String.concat "" (List.map size_contribution params) in
  declarations ^ sizes ^ allocation
(* Generate the trusted proxy for the untrusted function `ufunc';
 * `idx' is the index passed to sgx_ocall.
 *
 * Two shapes are produced:
 * - a naked function that does not propagate errno needs no
 *   marshalling structure and calls sgx_ocall with NULL;
 * - every other function allocates the marshalling structure with
 *   gen_ocalloc_block, fills one field per parameter, performs the
 *   OCALL and, unless it returns void, copies the return value back.
 *
 * In both shapes the epilogue copies `in,out' and `out' buffers back to
 * the caller, optionally restores errno, and returns status.
 *
 * Fix: sgx_ocfree() used to be emitted when the function had a return
 * value or parameters, whereas the ocalloc block is emitted whenever
 * the function is not (naked and errno-free).  A naked function with
 * errno propagation therefore allocated without freeing.  The release
 * is now emitted exactly when the allocation is.
 * NOTE(review): for all other functions the output is unchanged only
 * if is_naked_func means `returns void and takes no parameters';
 * confirm against its definition.
 *
 * The errno line now uses ms_struct_val, the name under which
 * gen_ocalloc_block declares the marshalling pointer, instead of a
 * hard-coded identifier.
 *)
let gen_func_tproxy (ufunc: Ast.untrusted_func) (idx: int) =
  let fd = ufunc.Ast.uf_fdecl in
  let propagate_errno = ufunc.Ast.uf_propagate_errno in
  (* True exactly when the proxy allocates a marshalling structure. *)
  let uses_ms = not (is_naked_func fd) || propagate_errno in
  let func_open = sprintf "%s\n{\n" (gen_tproxy_proto fd) in
  let local_vars = gen_tproxy_local_vars fd.Ast.plist in
  (* Copy `in,out' and `out' buffers back into the caller's memory. *)
  let handle_out_ptr plist =
    List.fold_left
      (fun acc (pty, declr) ->
         match pty with
         | Ast.PTVal _ -> acc
         | Ast.PTPtr (_, attr) ->
           (match attr.Ast.pa_direction with
            | Ast.PtrInOut | Ast.PtrOut ->
              let name = declr.Ast.identifier in
              acc ^ sprintf "\tif (%s) memcpy((void*)%s, %s, %s);\n"
                      name name (mk_parm_accessor name) (mk_len_var name)
            | _ -> acc))
      "" plist
  in
  let set_errno =
    if propagate_errno
    then sprintf "\terrno = %s->ocall_errno;" ms_struct_val
    else "" in
  let ocfree = if uses_ms then "\tsgx_ocfree();\n" else "" in
  let func_close = sprintf "%s%s\n%s%s\n"
      (handle_out_ptr fd.Ast.plist)
      set_errno
      ocfree
      "\treturn status;\n}" in
  if not uses_ms then
    let ocall_null = sprintf "status = sgx_ocall(%d, NULL);\n" idx in
    sprintf "%s\t%s\t%s%s" func_open local_vars ocall_null func_close
  else
    let ocalloc_ms_struct = gen_ocalloc_block fd.Ast.fname fd.Ast.plist in
    let fill_fields = List.map tproxy_fill_ms_field fd.Ast.plist in
    let ocall_with_ms =
      sprintf "status = sgx_ocall(%d, %s);\n" idx ms_struct_val in
    let update_retval =
      if fd.Ast.rtype <> Ast.Void
      then [ sprintf "if (%s) *%s = %s;"
               retval_name retval_name (mk_parm_accessor retval_name) ]
      else [] in
    let func_body =
      [ local_vars; ocalloc_ms_struct ]
      @ fill_fields
      @ [ ocall_with_ms ]
      @ update_retval in
    List.fold_left (fun acc s -> acc ^ "\t" ^ s ^ "\n") func_open func_body
    ^ func_close
(* Generate the definition of the static OCALL table: a struct holding
 * the number of OCALLs and an array with one untrusted-bridge pointer
 * per untrusted function.  When the enclave declares no untrusted
 * function the array has a single NULL entry so that its length is
 * never zero. *)
let gen_ocall_table (ec: enclave_content) =
  let table_entry (uf: Ast.untrusted_func) =
    let fd : Ast.func_decl = uf.Ast.uf_fdecl in
    "\t\t(void*)" ^ mk_ubridge_name ec.file_shortnm fd.Ast.fname ^ ",\n"
  in
  let nr_ocall = List.length ec.ufunc_decls in
  let table_name = mk_ocall_table_name ec.enclave_name in
  let table_init =
    if nr_ocall = 0 then "\t{ NULL },\n"
    else "\t{\n" ^ String.concat "" (List.map table_entry ec.ufunc_decls) ^ "\t}\n"
  in
  String.concat ""
    [ "static const struct {\n";
      "\tsize_t nr_ocall;\n";
      sprintf "\tvoid * table[%d];\n" (max nr_ocall 1);
      sprintf "} %s = {\n" table_name;
      sprintf "\t%d,\n" nr_ocall;
      table_init;
      "};\n" ]
(* Generate the untrusted `.c' file of the enclave `ec'.
 *
 * All code is generated first; the file is then written in this order:
 * include lines, marshalling structures (ms_writer), untrusted
 * bridges, the OCALL table, and the untrusted proxies.
 *
 * The output channel is closed on every path.  Previously an exception
 * raised while writing left the channel open and its buffer unflushed.
 * The exception is re-raised unchanged after the channel is closed. *)
let gen_untrusted_source (ec: enclave_content) =
  let code_fname = get_usource_name ec.file_shortnm in
  let include_hd = "#include \"" ^ get_uheader_short_name ec.file_shortnm ^ "\"\n" in
  let include_errno = "#include <errno.h>\n" in
  let uproxy_list =
    List.map2 (fun fd ecall_idx -> gen_func_uproxy fd ecall_idx ec)
      (tf_list_to_fd_list ec.tfunc_decls)
      (Util.mk_seq 0 (List.length ec.tfunc_decls - 1))
  in
  let ubridge_list =
    List.map (fun fd -> gen_func_ubridge ec.file_shortnm fd)
      (ec.ufunc_decls) in
  let write_lines chan = List.iter (fun s -> output_string chan (s ^ "\n")) in
  let out_chan = open_out code_fname in
  try
    output_string out_chan (include_hd ^ include_errno ^ "\n");
    ms_writer out_chan ec;
    write_lines out_chan ubridge_list;
    output_string out_chan (gen_ocall_table ec);
    write_lines out_chan uproxy_list;
    close_out out_chan
  with exn ->
    (* Best-effort close; the original error is the one worth reporting. *)
    close_out_noerr out_chan;
    raise exn
(* Generate the trusted `.c' file of the enclave `ec'.
 *
 * All code is generated first; the file is then written in this order:
 * include lines and the two pointer-checking macros, marshalling
 * structures (ms_writer), trusted bridges, the ECALL table, the entry
 * table, and the trusted proxies.
 *
 * The output channel is closed on every path.  Previously an exception
 * raised while writing left the channel open and its buffer unflushed.
 * The exception is re-raised unchanged after the channel is closed. *)
let gen_trusted_source (ec: enclave_content) =
  let code_fname = get_tsource_name ec.file_shortnm in
  let include_hd = "#include \"" ^ get_theader_short_name ec.file_shortnm ^ "\"\n\n\
                    #include \"sgx_trts.h\" /* for sgx_ocalloc, sgx_is_outside_enclave */\n\n\
                    #include <errno.h>\n\
                    #include <string.h> /* for memcpy etc */\n\
                    #include <stdlib.h> /* for malloc/free etc */\n\
                    \n\
                    #define CHECK_REF_POINTER(ptr, siz) do {\t\\\n\
                    \tif (!(ptr) || ! sgx_is_outside_enclave((ptr), (siz)))\t\\\n\
                    \t\treturn SGX_ERROR_INVALID_PARAMETER;\\\n\
                    } while (0)\n\
                    \n\
                    #define CHECK_UNIQUE_POINTER(ptr, siz) do {\t\\\n\
                    \tif ((ptr) && ! sgx_is_outside_enclave((ptr), (siz)))\t\\\n\
                    \t\treturn SGX_ERROR_INVALID_PARAMETER;\\\n\
                    } while (0)\n\
                    \n" in
  let trusted_fds = tf_list_to_fd_list ec.tfunc_decls in
  let tbridge_list =
    let dummy_var = tbridge_gen_dummy_variable ec in
    List.map (fun tfd -> gen_func_tbridge tfd dummy_var) trusted_fds in
  let ecall_table = gen_ecall_table ec.tfunc_decls in
  let entry_table = gen_entry_table ec in
  let tproxy_list = List.map2
      (fun fd idx -> gen_func_tproxy fd idx)
      (ec.ufunc_decls)
      (Util.mk_seq 0 (List.length ec.ufunc_decls - 1)) in
  let write_lines chan = List.iter (fun s -> output_string chan (s ^ "\n")) in
  let out_chan = open_out code_fname in
  try
    output_string out_chan (include_hd ^ "\n");
    ms_writer out_chan ec;
    write_lines out_chan tbridge_list;
    output_string out_chan (ecall_table ^ "\n");
    output_string out_chan (entry_table ^ "\n");
    output_string out_chan "\n";
    write_lines out_chan tproxy_list;
    close_out out_chan
  with exn ->
    (* Best-effort close; the original error is the one worth reporting. *)
    close_out_noerr out_chan;
    raise exn
(* A stack keeps track of the files currently being imported.
 *
 * A file is pushed onto the stack before it is parsed, and the stack
 * is popped after `parse_import_file' has dealt with it.
 *)
let already_read = SimpleStack.create ()

(* Push `fullpath' onto the import stack, failing when it is already
 * there. *)
let save_file fullpath =
  match SimpleStack.mem fullpath already_read with
  | true -> failwithf "detected circled import for `%s'" fullpath
  | false -> SimpleStack.push fullpath already_read
(* The entry point of the Edger8r parser front-end.
 * ------------------------------------------------
 * Parse the EDL file `fname' and return its AST with the enclave name
 * set to the short name of the file.
 *
 * The file is registered with save_file and run through the
 * preprocessor; when the preprocessor yields no text the file itself
 * is lexed.  An empty short name is a fatal error (exit code 1); a
 * short name that is not a valid C identifier only triggers a warning.
 * A parse error is reported through failwithf with the file name, line
 * and column of the offending token, and a Sys_error is reported
 * through failwithf as well.
 *
 * Fixes: the channel opened on the file is now closed once parsing has
 * succeeded or failed (it used to stay open), and the helper that sets
 * the initial position no longer takes an argument it ignored.
 *)
let start_parsing (fname: string) : Ast.enclave =
  let set_initial_pos lexbuf =
    lexbuf.Lexing.lex_curr_p <- {
      lexbuf.Lexing.lex_curr_p with Lexing.pos_fname = fname;
    }
  in
  try
    let fullpath = Util.get_file_path fname in
    let preprocessed =
      save_file fullpath; Preprocessor.processor_macro(fullpath) in
    (* `release_input' closes whatever the lexer reads from; it is a
     * no-op for in-memory input and may safely be called twice. *)
    let (lexbuf, release_input) =
      match preprocessed with
      | None ->
        let chan = open_in fullpath in
        (Lexing.from_channel chan, (fun () -> close_in_noerr chan))
      | Some(preprocessed_string) ->
        (Lexing.from_string preprocessed_string, (fun () -> ()))
    in
    try
      set_initial_pos lexbuf;
      let e : Ast.enclave = Parser.start_parsing Lexer.tokenize lexbuf in
      release_input ();
      let short_name = Util.get_short_name fname in
      if short_name = ""
      then (eprintf "error: %s: file short name is empty\n" fname; exit 1)
      else
        let res = { e with Ast.ename = short_name } in
        if Util.is_c_identifier short_name then res
        else (eprintf "warning: %s: file short name `%s' is not a valid C identifier\n" fname short_name; res)
    with exn ->
      release_input ();
      begin match exn with
        | Parsing.Parse_error ->
          (* The lexer buffer still holds the position and the text of
           * the last token even though the channel is closed. *)
          let curr = lexbuf.Lexing.lex_curr_p in
          let line = curr.Lexing.pos_lnum in
          let cnum = curr.Lexing.pos_cnum - curr.Lexing.pos_bol in
          let tok = Lexing.lexeme lexbuf in
          failwithf "%s:%d:%d: unexpected token: %s\n" fname line cnum tok
        | _ -> raise exn
      end
  with Sys_error s -> failwithf "%s\n" s
(* Check for duplicated ECALL/OCALL names.
 *
 * Trusted functions are visited first, then untrusted ones; the first
 * name seen twice aborts with an error.  This is a simple
 * implementation - to improve it, the location information of each
 * token should be carried to the AST.
 *)
let check_duplication (ec: enclave_content) =
  let seen = Hashtbl.create 10 in
  let trusted_fds = tf_list_to_fd_list ec.tfunc_decls in
  let untrusted_fds = uf_list_to_fd_list ec.ufunc_decls in
  let register (fd: Ast.func_decl) =
    let fname = fd.Ast.fname in
    if not (Hashtbl.mem seen fname) then Hashtbl.add seen fname true
    else failwithf "Multiple definition of function \"%s\" detected." fname
  in
  List.iter register (trusted_fds @ untrusted_fds)
(* For each untrusted function, check that every ECALL named in its
 * allow list is a trusted function of the enclave.  The first unknown
 * name aborts with an error. *)
let check_allow_list (ec: enclave_content) =
  let known_ecalls = get_trusted_func_names ec in
  let check_ocall (uf: Ast.untrusted_func) =
    let ocall_name = uf.Ast.uf_fdecl.Ast.fname in
    let require_known ecall =
      if not (List.mem ecall known_ecalls) then
        failwithf "\"%s\" declared to allow unknown function \"%s\"."
          ocall_name ecall
    in
    List.iter require_known uf.Ast.uf_allow_list
  in
  List.iter check_ocall ec.ufunc_decls
(* Print a warning on stderr for every private ECALL whose name is
 * not among the names returned by get_allowed_names for the untrusted
 * functions. *)
let report_orphaned_priv_ecall (ec: enclave_content) =
  let priv_ecall_names = get_priv_ecall_names ec.tfunc_decls in
  let allowed_names = get_allowed_names ec.ufunc_decls in
  priv_ecall_names
  |> List.filter (fun n -> not (List.mem n allowed_names))
  |> List.iter (fun n ->
      eprintf "warning: private ECALL `%s' is not used by any OCALL\n" n)
(* Fail unless the enclave has at least one public ECALL; when it has
 * one, go on to report the private ECALLs that no OCALL allows. *)
let check_priv_funcs (ec: enclave_content) =
  let has_public =
    List.exists (fun is_priv -> not is_priv) (tf_list_to_priv_list ec.tfunc_decls) in
  if has_public then report_orphaned_priv_ecall ec
  else failwithf "the enclave `%s' contains no public root ECALL.\n" ec.file_shortnm
(* When generating edge-routines, we first need to check whether there
 * are `import' expressions inside the EDL.  If so, each imported file
 * is parsed into an `enclave_content' record, recursively.
 *
 * `ec' is the toplevel `enclave_content' record.
 * A tree-reduce algorithm is used here: `ec' is the root node and each
 * `import' expression is considered a child.
 *
 * The result is `ec' combined with everything it imports, with an
 * empty list of import expressions.
 *)
let reduce_import (ec: enclave_content) =
(* Append the includes, composite definitions and function declarations
 * of `ec2' to those of `ec1'; the result has no pending imports. *)
let combine (ec1: enclave_content) (ec2: enclave_content) =
{ ec1 with
include_list = ec1.include_list @ ec2.include_list;
import_exprs = [];
comp_defs = ec1.comp_defs @ ec2.comp_defs;
tfunc_decls = ec1.tfunc_decls @ ec2.tfunc_decls;
ufunc_decls = ec1.ufunc_decls @ ec2.ufunc_decls; }
in
(* Parse the imported file `fname'.  The import stack is popped here
 * only when the parsed file has no imports of its own.
 * NOTE(review): a file with nested imports is never popped in this
 * function - confirm that this is intended. *)
let parse_import_file fname =
let ec = parse_enclave_ast (start_parsing fname)
in
match ec.import_exprs with
[] -> (SimpleStack.pop already_read |> ignore; ec )
| _ -> ec
in
let check_funs funcs (ec: enclave_content) =
(* Check whether `funcs' are listed in `ec'. It returns a
pair (x, y), where:
x - functions not listed in `ec';
y - a new `ec' that contains functions from `funcs' listed in `ec'.
*)
let enclave_funcs =
let trusted_func_names = get_trusted_func_names ec in
let untrusted_func_names = get_untrusted_func_names ec in
trusted_func_names @ untrusted_func_names
in
let in_ec_def name = List.exists (fun x -> x = name) enclave_funcs in
let in_import_list name = List.exists (fun x -> x = name) funcs in
let x = List.filter (fun name -> not (in_ec_def name)) funcs in
(* `y' starts from `empty_ec', so only the selected function
 * declarations are taken from `ec'. *)
let y =
{ empty_ec with
tfunc_decls = List.filter (fun tf ->
in_import_list (get_tf_fname tf)) ec.tfunc_decls;
ufunc_decls = List.filter (fun uf ->
in_import_list (get_uf_fname uf)) ec.ufunc_decls; }
in (x, y)
in
(* Import functions listed in `funcs' from `importee'. *)
let rec import_funcs (funcs: string list) (importee: enclave_content) =
(* A `*' means importing all the functions. *)
if List.exists (fun x -> x = "*") funcs
then
(* Keep `importee' itself and fold in every file it imports, each
 * restricted to the function list of its import declaration. *)
List.fold_left (fun acc (ipd: Ast.import_decl) ->
let next_ec = parse_import_file ipd.Ast.mname
in combine acc (import_funcs ipd.Ast.flist next_ec)) importee importee.import_exprs
else
let (x, y) = check_funs funcs importee
in
if x = [] then y (* Resolved all importings *)
else
(* The functions in `x' were not found in `importee'; look for them
 * in the files it imports, or fail when it imports nothing. *)
match importee.import_exprs with
[] -> failwithf "import failed - functions `%s' not found" (List.hd x)
| ex -> List.fold_left (fun acc (ipd: Ast.import_decl) ->
let next_ec = parse_import_file ipd.Ast.mname
in combine acc (import_funcs x next_ec)) y ex
in
import_funcs ["*"] ec
(* Generate the enclave code for the parsed EDL `e' according to the
 * command-line parameters `ep'.
 *
 * Imports are resolved first.  The global settings are then taken from
 * `ep', the output directories are created, and the consistency checks
 * are run.  Headers are written for each requested side; source files
 * and the public-ECALL check are skipped in header-only mode. *)
let gen_enclave_code (e: Ast.enclave) (ep: edger8r_params) =
  let ec = reduce_import (parse_enclave_ast e) in
  let with_sources = not ep.header_only in
  g_use_prefix := ep.use_prefix;
  g_untrusted_dir := ep.untrusted_dir;
  g_trusted_dir := ep.trusted_dir;
  create_dir ep.untrusted_dir;
  create_dir ep.trusted_dir;
  check_duplication ec;
  check_allow_list ec;
  if with_sources then check_priv_funcs ec;
  if ep.gen_untrusted then begin
    gen_untrusted_header ec;
    if with_sources then gen_untrusted_source ec
  end;
  if ep.gen_trusted then begin
    gen_trusted_header ec;
    if with_sources then gen_trusted_source ec
  end