Skip to content

Commit 18a975c

Browse files
author
DvirDukhan
committed
restored torch_c.h
1 parent cc11bc1 commit 18a975c

File tree

1 file changed

+191
-0
lines changed

1 file changed

+191
-0
lines changed

src/backends/libtorch_c/torch_c.h

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
#pragma once
2+
3+
#include "dlpack/dlpack.h"
4+
5+
#ifdef __cplusplus
6+
extern "C" {
7+
#endif
8+
9+
#include "redis_ai_objects/script_struct.h"
10+
11+
typedef struct TorchFunctionInputCtx {
12+
DLManagedTensor **tensorInputs;
13+
size_t tensorCount;
14+
int32_t *intInputs;
15+
size_t intCount;
16+
float *floatInputs;
17+
size_t floatCount;
18+
RedisModuleString **stringsInputs;
19+
size_t stringCount;
20+
size_t *listSizes;
21+
size_t listCount;
22+
} TorchFunctionInputCtx;
23+
24+
/**
25+
* @brief Compiles a script string into torch compliation unit stored in a module context.
26+
*
27+
* @param script Script string.
28+
* @param device Device for the script to execute on.
29+
* @param device_id Device id for the script to execute on.
30+
* @param error Error string to be populated in case of an exception.
31+
* @return void* ModuleContext pointer.
32+
*/
33+
void *torchCompileScript(const char *script, DLDeviceType device, int64_t device_id, char **error);
34+
35+
/**
36+
* @brief Loads a model from model definition string and stores it in a module context.
37+
*
38+
* @param model Model definition string.
39+
* @param modellen Length of the string.
40+
* @param device Device for the model to execute on.
41+
* @param device_id Device id for the model to execute on.
42+
* @param error Error string to be populated in case of an exception.
43+
* @return void* ModuleContext pointer.
44+
*/
45+
void *torchLoadModel(const char *model, size_t modellen, DLDeviceType device, int64_t device_id,
46+
char **error);
47+
48+
/**
49+
* @brief Validate SCRIPTEXECUTE or LLAPI script execute inputs according to the funciton schema.
50+
*
51+
* @param schema Fuction argument types (schema).
52+
* @param nArguments Number of arguments in the function.
53+
* @param inputsCtx Function execution context containing the information about given inputs.
54+
* @param error Error string to be populated in case of an exception.
55+
* @return true If the user provided inputs from types and order that matches the schema.
56+
* @return false Otherwise.
57+
*/
58+
bool torchMatchScriptSchema(TorchScriptFunctionArgumentType *schema, size_t nArguments,
59+
TorchFunctionInputCtx *inputsCtx, char **error);
60+
61+
/**
62+
* @brief Executes a function in a script.
63+
* @note Should be called after torchMatchScriptSchema verication.
64+
* @param scriptCtx Executes a function in a script.
65+
* @param fnName Function name.
66+
* @param schema Fuction argument types (schema).
67+
* @param nArguments Number of arguments in the function.
68+
* @param inputsCtx unction execution context containing the information about given inputs.
69+
* @param outputs Array of output tensor (placeholders).
70+
* @param nOutputs Number of output tensors.
71+
* @param error Error string to be populated in case of an exception.
72+
*/
73+
void torchRunScript(void *scriptCtx, const char *fnName, TorchScriptFunctionArgumentType *schema,
74+
size_t nArguments, TorchFunctionInputCtx *inputsCtx, DLManagedTensor **outputs,
75+
long nOutputs, char **error);
76+
77+
/**
78+
* @brief Executes a model.
79+
*
80+
* @param modelCtx Model context.
81+
* @param nInputs Number of tensor inputs.
82+
* @param inputs Array of input tensors.
83+
* @param nOutputs Number of output tensors.
84+
* @param outputs Array of output tensor (placeholders).
85+
* @param error Error string to be populated in case of an exception.
86+
*/
87+
void torchRunModel(void *modelCtx, long nInputs, DLManagedTensor **inputs, long nOutputs,
88+
DLManagedTensor **outputs, char **error);
89+
90+
/**
91+
* @brief
92+
*
93+
* @param modelCtx Serilized a model into a string defintion.
94+
* @param buffer Byte array to hold the definition.
95+
* @param len Will store the length of the string.
96+
* @param error Error string to be populated in case of an exception.
97+
*/
98+
void torchSerializeModel(void *modelCtx, char **buffer, size_t *len, char **error);
99+
100+
/**
101+
* @brief Deallicate the create torch script/model object.
102+
*
103+
* @param ctx Object to free.
104+
*/
105+
void torchDeallocContext(void *ctx);
106+
107+
/**
108+
* @brief Sets the number of inter-op threads for Torch backend.
109+
*
110+
* @param num_threads Number of inter-op threads.
111+
* @param error Error string to be populated in case of an exception.
112+
*/
113+
void torchSetInterOpThreads(int num_threads, char **error);
114+
115+
/**
116+
* @brief Sets the number of intra-op threads for Torch backend.
117+
*
118+
* @param num_threads Number of intra-op threads.
119+
* @param error Error string to be populated in case of an exception.
120+
*/
121+
void torchSetIntraOpThreads(int num_threadsm, char **error);
122+
123+
/**
124+
* @brief Returns the number of inputs of a model
125+
*
126+
* @param modelCtx Model context.
127+
* @param error Error string to be populated in case of an exception.
128+
* @return size_t Number of inputs.
129+
*/
130+
size_t torchModelNumInputs(void *modelCtx, char **error);
131+
132+
/**
133+
* @brief Returns the name of the model input at index.
134+
*
135+
* @param modelCtx Model context.
136+
* @param index Input index.
137+
* @param error Error string to be populated in case of an exception.
138+
* @return const char* Input name.
139+
*/
140+
const char *torchModelInputNameAtIndex(void *modelCtx, size_t index, char **error);
141+
142+
/**
143+
* @brief Returns the number of outputs of a model
144+
*
145+
* @param modelCtx Model context.
146+
* @param error Error string to be populated in case of an exception.
147+
* @return size_t Number of outputs.
148+
*/
149+
size_t torchModelNumOutputs(void *modelCtx, char **error);
150+
151+
/**
152+
* @brief Return the number of functions in the script.
153+
*
154+
* @param scriptCtx Script context.
155+
* @return size_t number of functions.
156+
*/
157+
size_t torchScript_FunctionCount(void *scriptCtx);
158+
159+
/**
160+
* @brief Return the name of the function numbered fn_index in the script.
161+
*
162+
* @param scriptCtx Script context.
163+
* @param fn_index Function number.
164+
* @return const char* Function name.
165+
*/
166+
const char *torchScript_FunctionName(void *scriptCtx, size_t fn_index);
167+
168+
/**
169+
* @brief Return the number of arguments in the fuction numbered fn_index in the script.
170+
*
171+
* @param scriptCtx Script context.
172+
* @param fn_index Function number.
173+
* @return size_t Number of arguments.
174+
*/
175+
size_t torchScript_FunctionArgumentCount(void *scriptCtx, size_t fn_index);
176+
177+
/**
178+
* @brief Rerturns the type of the argument at arg_index of function numbered fn_index in the
179+
* script.
180+
*
181+
* @param scriptCtx Script context.
182+
* @param fn_index Function number.
183+
* @param arg_index Argument number.
184+
* @return TorchScriptFunctionArgumentType The type of the argument in RedisAI enum format.
185+
*/
186+
TorchScriptFunctionArgumentType torchScript_FunctionArgumentype(void *scriptCtx, size_t fn_index,
187+
size_t arg_index);
188+
189+
#ifdef __cplusplus
190+
}
191+
#endif

0 commit comments

Comments
 (0)