@@ -136,7 +136,8 @@ TF_Tensor* RAI_TFTensorFromTensor(RAI_Tensor* t){
136136}
137137
138138
139- RAI_Graph * RAI_GraphCreateTF (const char * prefix , RAI_Backend backend ,
139+ RAI_Graph * RAI_GraphCreateTF (const char * prefix ,
140+ RAI_Backend backend , RAI_Device device ,
140141 const char * graphdef , size_t graphlen ) {
141142 TF_Graph * graph = TF_NewGraph ();
142143
@@ -160,9 +161,31 @@ RAI_Graph *RAI_GraphCreateTF(const char *prefix, RAI_Backend backend,
160161 TF_DeleteBuffer (buffer );
161162 TF_DeleteStatus (status );
162163
163- TF_Status * sessionStatus = TF_NewStatus ();
164+ TF_Status * optionsStatus = TF_NewStatus ();
164165
165166 TF_SessionOptions * sessionOptions = TF_NewSessionOptions ();
167+
168+ // For setting config options in session from the C API see:
169+ // https://github.com/tensorflow/tensorflow/issues/13853
170+ // import tensorflow as tf
171+ // config = tf.ConfigProto()
172+ // config.intra_op_parallelism_threads = 1
173+ // serialized = config.SerializeToString()
174+ // result = list(map(hex, serialized))
175+
176+ // TODO: complain if device is GPU and GPU not available?
177+ if (device == RAI_DEVICE_CPU ) {
178+ uint8_t config [9 ] = {0x0a , 0x07 , 0x0a , 0x03 , 0x47 , 0x50 , 0x55 , 0x10 , 0x00 };
179+ TF_SetConfig (sessionOptions , (void * )config , 9 , status );
180+ }
181+
182+ if (TF_GetCode (optionsStatus ) != TF_OK ) {
183+ // TODO: free memory
184+ return NULL ;
185+ }
186+ TF_DeleteStatus (optionsStatus );
187+
188+ TF_Status * sessionStatus = TF_NewStatus ();
166189 TF_Session * session = TF_NewSession (graph , sessionOptions , sessionStatus );
167190
168191 if (TF_GetCode (sessionStatus ) != TF_OK ) {
0 commit comments