Skip to content

Commit e461f9a

Browse files
committed
Co-Authored-by: Josh Long <54473+joshlong@users.noreply.github.com>
Signed-off-by: Ilayaperumal Gopinathan <ilayaperumal.gopinathan@broadcom.com>
1 parent e59be78 commit e461f9a

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616

1717
package org.springframework.ai.openai.aot;
1818

19+
import java.util.Set;
20+
21+
import org.springframework.ai.openai.api.OpenAiApi;
22+
import org.springframework.ai.openai.api.OpenAiEmbeddingDeserializer;
1923
import org.springframework.aot.hint.MemberCategory;
2024
import org.springframework.aot.hint.RuntimeHints;
2125
import org.springframework.aot.hint.RuntimeHintsRegistrar;
@@ -38,6 +42,11 @@ public class OpenAiRuntimeHints implements RuntimeHintsRegistrar {
3842
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
3943
var mcs = MemberCategory.values();
4044

45+
for (var c : Set.of(OpenAiApi.Embedding.class, OpenAiApi.EmbeddingList.class,
46+
OpenAiEmbeddingDeserializer.class)) {
47+
hints.reflection().registerType(c, MemberCategory.values());
48+
}
49+
4150
for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.openai")) {
4251
hints.reflection().registerType(tr, mcs);
4352
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/aot/OpenAiRuntimeHintsTests.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.springframework.ai.openai.OpenAiChatOptions;
2626
import org.springframework.ai.openai.api.OpenAiApi;
2727
import org.springframework.ai.openai.api.OpenAiAudioApi;
28+
import org.springframework.ai.openai.api.OpenAiEmbeddingDeserializer;
2829
import org.springframework.ai.openai.api.OpenAiImageApi;
2930
import org.springframework.aot.hint.MemberCategory;
3031
import org.springframework.aot.hint.RuntimeHints;
@@ -312,4 +313,17 @@ void verifyJsonAnnotatedClassesContainCriticalTypes() {
312313
assertThat(containsImageApi).isTrue();
313314
}
314315

316+
@Test
317+
void verifyExplicitlyRegisteredEmbeddingClasses() {
318+
this.openAiRuntimeHints.registerHints(this.runtimeHints, null);
319+
320+
Set<TypeReference> registeredTypes = new HashSet<>();
321+
this.runtimeHints.reflection().typeHints().forEach(typeHint -> registeredTypes.add(typeHint.getType()));
322+
323+
// Verify the three classes explicitly registered in OpenAiRuntimeHints
324+
assertThat(registeredTypes.contains(TypeReference.of(OpenAiApi.Embedding.class))).isTrue();
325+
assertThat(registeredTypes.contains(TypeReference.of(OpenAiApi.EmbeddingList.class))).isTrue();
326+
assertThat(registeredTypes.contains(TypeReference.of(OpenAiEmbeddingDeserializer.class))).isTrue();
327+
}
328+
315329
}

0 commit comments

Comments
 (0)