Skip to content

Commit d06c2d2

Browse files
committed
closes ardoco#2 for now
Added the new naive iterative optimizer; changed the visibility of several helper functions; added a new constructor to ElementStore.java to create an ElementStore as a subset of an existing one; refactored some magic strings into public constants.
1 parent 50cf0b9 commit d06c2d2

8 files changed

Lines changed: 291 additions & 22 deletions

File tree

03_Code/LiSSA_RATLR/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/Optimization.java

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
/* Licensed under MIT 2025. */
22
package edu.kit.kastel.sdq.lissa.ratlr;
33

4+
import java.io.IOException;
5+
import java.nio.file.Path;
6+
import java.util.Objects;
7+
8+
import org.slf4j.Logger;
9+
import org.slf4j.LoggerFactory;
10+
411
import com.fasterxml.jackson.databind.ObjectMapper;
12+
513
import edu.kit.kastel.sdq.lissa.ratlr.artifactprovider.ArtifactProvider;
614
import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheManager;
715
import edu.kit.kastel.sdq.lissa.ratlr.configuration.Configuration;
816
import edu.kit.kastel.sdq.lissa.ratlr.elementstore.ElementStore;
917
import edu.kit.kastel.sdq.lissa.ratlr.embeddingcreator.EmbeddingCreator;
1018
import edu.kit.kastel.sdq.lissa.ratlr.preprocessor.Preprocessor;
1119
import edu.kit.kastel.sdq.lissa.ratlr.promptoptimizer.AbstractPromptOptimizer;
12-
import org.slf4j.Logger;
13-
import org.slf4j.LoggerFactory;
14-
15-
import java.io.IOException;
16-
import java.nio.file.Path;
17-
import java.util.Objects;
1820

1921
/**
2022
* Represents a single prompt optimization run of the LiSSA framework.
@@ -122,7 +124,8 @@ private void setup() throws IOException {
122124
sourceStore = new ElementStore(configuration.sourceStore(), false);
123125
targetStore = new ElementStore(configuration.targetStore(), true);
124126

125-
promptOptimizer = AbstractPromptOptimizer.createOptimizer(configuration.promptOptimizer());
127+
promptOptimizer = AbstractPromptOptimizer.createOptimizer(
128+
configuration.promptOptimizer(), configuration.goldStandardConfiguration());
126129
configuration.serializeAndDestroyConfiguration();
127130
}
128131

@@ -166,7 +169,10 @@ public String run() {
166169
targetStore.setup(targetElements, targetEmbeddings);
167170

168171
logger.info("Optimizing Prompt");
169-
String result = promptOptimizer.optimize(sourceStore, targetStore, "Question: Here are two parts of software development artifacts.\n\n {source_type}: '''{source_content}'''\n\n {target_type}: '''{target_content}'''\n Are they related?\n\n Answer with 'yes' or 'no'.");
172+
String result = promptOptimizer.optimize(
173+
sourceStore,
174+
targetStore,
175+
"Question: Here are two parts of software development artifacts.\n\n {source_type}: '''{source_content}'''\n\n {target_type}: '''{target_content}'''\n Are they related?\n\n Answer with 'yes' or 'no'.");
170176
logger.info("Optimized Prompt: {}", result);
171177
return result;
172178
}

03_Code/LiSSA_RATLR/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/Statistics.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ public static void generateStatistics(
168168
* @throws UncheckedIOException If there are issues reading the gold standard file
169169
*/
170170
@NotNull
171-
private static Set<TraceLink> getTraceLinksFromGoldStandard(GoldStandardConfiguration goldStandardConfiguration) {
171+
public static Set<TraceLink> getTraceLinksFromGoldStandard(GoldStandardConfiguration goldStandardConfiguration) {
172172
File groundTruth = new File(goldStandardConfiguration.path());
173173
boolean header = goldStandardConfiguration.hasHeader();
174174
logger.info("Skipping header: {}", header);

03_Code/LiSSA_RATLR/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/ChatLanguageModelProvider.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public class ChatLanguageModelProvider {
6666
*/
6767
public static final int DEFAULT_SEED = 133742243;
6868

69-
//TODO: Refactor as Module Parameter?
69+
// TODO: Refactor as Module Parameter?
7070
/**
7171
* Models that do not support temperature settings. Lower temperature values mean less creativity and variation
7272
* in the model's responses.
@@ -155,6 +155,15 @@ public int seed() {
155155
return seed;
156156
}
157157

158+
/**
159+
* Gets the platform for which the language model is configured.
160+
*
161+
* @return The platform name
162+
*/
163+
public String platform() {
164+
return platform;
165+
}
166+
158167
/**
159168
* Determines the number of threads to use based on the platform.
160169
* OpenAI and Blablador platforms use 100 threads, while others use 1.
@@ -213,7 +222,7 @@ private static OpenAiChatModel createOpenAiChatModel(String model, int seed) {
213222
// Ideal temperature for most deterministic results
214223
double temperature = 0.0;
215224
// Set temperature based on the model type as some do not support temperature values
216-
for(String modelWithoutTemperature : MODELS_WITHOUT_TEMPERATURE) {
225+
for (String modelWithoutTemperature : MODELS_WITHOUT_TEMPERATURE) {
217226
if (model.equals(modelWithoutTemperature)) {
218227
temperature = 1.0;
219228
break;

03_Code/LiSSA_RATLR/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/classifier/SimpleClassifier.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
*/
2121
public class SimpleClassifier extends Classifier {
2222

23+
public static final String PROMPT_TEMPLATE_KEY = "template";
24+
2325
/**
2426
* The default template for classification requests.
2527
* This template presents two artifacts and asks if they are related.
@@ -61,7 +63,7 @@ public class SimpleClassifier extends Classifier {
6163
public SimpleClassifier(ModuleConfiguration configuration) {
6264
super(ChatLanguageModelProvider.threads(configuration));
6365
this.provider = new ChatLanguageModelProvider(configuration);
64-
this.template = configuration.argumentAsString("template", DEFAULT_TEMPLATE);
66+
this.template = configuration.argumentAsString(PROMPT_TEMPLATE_KEY, DEFAULT_TEMPLATE);
6567
this.cache = CacheManager.getDefaultInstance()
6668
.getCache(this.getClass().getSimpleName() + "_" + provider.modelName() + "_" + provider.seed());
6769
this.llm = provider.createChatModel();

03_Code/LiSSA_RATLR/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/elementstore/ElementStore.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,36 @@ public ElementStore(ModuleConfiguration configuration, boolean similarityRetriev
9595
idToElementWithEmbedding = new HashMap<>();
9696
}
9797

98+
/**
99+
* Creates a new element store with the provided content.
100+
* This constructor is used for initializing the store with existing elements and their embeddings.
101+
*
102+
* @param content List of pairs containing elements and their embeddings
103+
* @param maxResults The maximum number of results to return in similarity search
104+
* -1 indicates source store mode (no similarity search).
105+
* Positive values indicate target store mode with a limit on results.
106+
* @throws IllegalArgumentException If maxResults is less than -1 or equals 0
107+
*/
108+
public ElementStore(List<Pair<Element, float[]>> content, int maxResults) {
109+
if (maxResults < -1 || maxResults == 0) {
110+
throw new IllegalArgumentException(
111+
"The maximum number of results must be -1 to indicate a source store or greater than 0 to indicate a target store.");
112+
}
113+
this.maxResults = maxResults;
114+
115+
elementsWithEmbedding = new ArrayList<>();
116+
idToElementWithEmbedding = new HashMap<>();
117+
List<Element> elements = new ArrayList<>();
118+
List<float[]> embeddings = new ArrayList<>();
119+
for (var pair : content) {
120+
var element = pair.first();
121+
var embedding = pair.second();
122+
elements.add(element);
123+
embeddings.add(Arrays.copyOf(embedding, embedding.length));
124+
}
125+
setup(elements, embeddings);
126+
}
127+
98128
/**
99129
* Initializes the element store with elements and their embeddings for LiSSA's processing.
100130
*
@@ -225,6 +255,10 @@ public List<Pair<Element, float[]>> getAllElements(boolean onlyCompare) {
225255
return getAllElementsIntern(onlyCompare);
226256
}
227257

258+
public List<Element> getAllElements() {
259+
return getAllElementsIntern(false).stream().map(Pair::first).toList();
260+
}
261+
228262
/**
229263
* Internal method to retrieve all elements.
230264
* Available in both source and target store modes for LiSSA's internal processing.

03_Code/LiSSA_RATLR/src/main/java/edu/kit/kastel/sdq/lissa/ratlr/promptoptimizer/AbstractPromptOptimizer.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import org.slf4j.Logger;
55
import org.slf4j.LoggerFactory;
66

7+
import edu.kit.kastel.sdq.lissa.ratlr.configuration.GoldStandardConfiguration;
78
import edu.kit.kastel.sdq.lissa.ratlr.configuration.ModuleConfiguration;
89
import edu.kit.kastel.sdq.lissa.ratlr.elementstore.ElementStore;
910

@@ -20,13 +21,15 @@ protected AbstractPromptOptimizer(int threads) {
2021
this.threads = Math.max(1, threads);
2122
}
2223

23-
public static AbstractPromptOptimizer createOptimizer(ModuleConfiguration configuration) {
24+
public static AbstractPromptOptimizer createOptimizer(
25+
ModuleConfiguration configuration, GoldStandardConfiguration goldStandard) {
2426
if (configuration == null) {
2527
return new MockOptimizer();
2628
}
2729
return switch (configuration.name().split(CONFIG_NAME_SEPARATOR)[0]) {
2830
case "mock" -> new MockOptimizer();
2931
case "simple" -> new SimpleOptimizer(configuration);
32+
case "iterative" -> new IterativeOptimizer(configuration, goldStandard);
3033
default -> throw new IllegalStateException("Unexpected value: " + configuration.name());
3134
};
3235
}
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
/* Licensed under MIT 2025. */
2+
package edu.kit.kastel.sdq.lissa.ratlr.promptoptimizer;
3+
4+
import static edu.kit.kastel.sdq.lissa.ratlr.Statistics.getTraceLinksFromGoldStandard;
5+
import static edu.kit.kastel.sdq.lissa.ratlr.classifier.SimpleClassifier.PROMPT_TEMPLATE_KEY;
6+
import static edu.kit.kastel.sdq.lissa.ratlr.promptoptimizer.SimpleOptimizer.DEFAULT_OPTIMIZATION_TEMPLATE;
7+
import static edu.kit.kastel.sdq.lissa.ratlr.promptoptimizer.SimpleOptimizer.ORIGINAL_PROMPT_KEY;
8+
import static edu.kit.kastel.sdq.lissa.ratlr.promptoptimizer.SimpleOptimizer.PROMPT_END;
9+
import static edu.kit.kastel.sdq.lissa.ratlr.promptoptimizer.SimpleOptimizer.PROMPT_START;
10+
11+
import java.util.List;
12+
import java.util.Map;
13+
import java.util.Set;
14+
import java.util.regex.Matcher;
15+
import java.util.regex.Pattern;
16+
import java.util.stream.Collectors;
17+
18+
import edu.kit.kastel.mcse.ardoco.metrics.ClassificationMetricsCalculator;
19+
import edu.kit.kastel.sdq.lissa.ratlr.cache.Cache;
20+
import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheKey;
21+
import edu.kit.kastel.sdq.lissa.ratlr.cache.CacheManager;
22+
import edu.kit.kastel.sdq.lissa.ratlr.classifier.ChatLanguageModelProvider;
23+
import edu.kit.kastel.sdq.lissa.ratlr.classifier.ClassificationResult;
24+
import edu.kit.kastel.sdq.lissa.ratlr.classifier.Classifier;
25+
import edu.kit.kastel.sdq.lissa.ratlr.configuration.GoldStandardConfiguration;
26+
import edu.kit.kastel.sdq.lissa.ratlr.configuration.ModuleConfiguration;
27+
import edu.kit.kastel.sdq.lissa.ratlr.elementstore.ElementStore;
28+
import edu.kit.kastel.sdq.lissa.ratlr.knowledge.Element;
29+
import edu.kit.kastel.sdq.lissa.ratlr.knowledge.TraceLink;
30+
import edu.kit.kastel.sdq.lissa.ratlr.postprocessor.TraceLinkIdPostprocessor;
31+
import edu.kit.kastel.sdq.lissa.ratlr.resultaggregator.ResultAggregator;
32+
import edu.kit.kastel.sdq.lissa.ratlr.utils.KeyGenerator;
33+
34+
import dev.langchain4j.model.chat.ChatModel;
35+
36+
public class IterativeOptimizer extends AbstractPromptOptimizer {
37+
38+
private static final double THRESHOLD_F1_SCORE = 1.0;
39+
40+
/**
41+
* The maximum number of iterations/requests for the optimization process.
42+
*/
43+
private static final int MAXIMUM_ITERATIONS = 10;
44+
/**
45+
* The size of the training data used for optimization.
46+
* This is the number of training examples provided to the language model.
47+
*/
48+
private static final int TRAINING_DATA_SIZE = 5;
49+
50+
private final Cache cache;
51+
52+
/**
53+
* Provider for the language model used in classification.
54+
*/
55+
private final ChatLanguageModelProvider provider;
56+
57+
/**
58+
* The language model instance used for classification.
59+
*/
60+
private final ChatModel llm;
61+
62+
/**
63+
* The template used for classification requests.
64+
*/
65+
private final String template;
66+
67+
private String optimizationPrompt;
68+
69+
private ResultAggregator aggregator;
70+
private TraceLinkIdPostprocessor traceLinkIdPostProcessor;
71+
private final Set<TraceLink> validTraceLinks;
72+
private ClassificationMetricsCalculator cmc;
73+
/**
74+
* Creates a new iterative optimizer with the specified configuration.
75+
*
76+
* @param configuration The module configuration containing optimizer settings
77+
*/
78+
public IterativeOptimizer(ModuleConfiguration configuration, GoldStandardConfiguration goldStandard) {
79+
super(ChatLanguageModelProvider.threads(configuration));
80+
this.provider = new ChatLanguageModelProvider(configuration);
81+
this.template = configuration.argumentAsString("optimization_template", DEFAULT_OPTIMIZATION_TEMPLATE);
82+
this.cache = CacheManager.getDefaultInstance()
83+
.getCache(this.getClass().getSimpleName() + "_" + provider.modelName() + "_" + provider.seed());
84+
this.llm = provider.createChatModel();
85+
this.validTraceLinks = getTraceLinksFromGoldStandard(goldStandard);
86+
setup();
87+
}
88+
89+
private IterativeOptimizer(
90+
int threads,
91+
Cache cache,
92+
ChatLanguageModelProvider provider,
93+
String template,
94+
Set<TraceLink> validTraceLinks) {
95+
super(threads);
96+
this.cache = cache;
97+
this.provider = provider;
98+
this.template = template;
99+
this.llm = provider.createChatModel();
100+
this.validTraceLinks = validTraceLinks;
101+
setup();
102+
}
103+
/**
104+
* TODO: Configure in actual configuration file. Decide whether they should be args or adapted
105+
* Req2Req Example:
106+
* "result_aggregator" : {
107+
* "name" : "any_connection",
108+
* "args" : {}
109+
* },
110+
* "tracelinkid_postprocessor" : {
111+
* "name" : "req2req",
112+
* "args" : {}
113+
*/
114+
private void setup() {
115+
cmc = ClassificationMetricsCalculator.getInstance();
116+
this.aggregator = ResultAggregator.createResultAggregator(new ModuleConfiguration("any_connection", Map.of()));
117+
118+
this.traceLinkIdPostProcessor =
119+
TraceLinkIdPostprocessor.createTraceLinkIdPostprocessor(new ModuleConfiguration("req2req", Map.of()));
120+
}
121+
122+
@Override
123+
public String optimize(ElementStore sourceStore, ElementStore targetStore, String prompt) {
124+
Element source = sourceStore.getAllElements(true).getFirst().first();
125+
Element target = targetStore
126+
.findSimilar(sourceStore.getAllElements(true).getFirst().second())
127+
.getFirst();
128+
optimizationPrompt =
129+
template.replace("{source_type}", source.getType()).replace("{target_type}", target.getType());
130+
ElementStore trainingSourceStore =
131+
new ElementStore(sourceStore.getAllElements(false).subList(0, TRAINING_DATA_SIZE), -1);
132+
133+
return optimizeIntern(trainingSourceStore, targetStore, prompt);
134+
}
135+
136+
private String optimizeIntern(ElementStore sourceStore, ElementStore targetStore, String prompt) {
137+
double[] f1Scores = new double[MAXIMUM_ITERATIONS];
138+
int i = 0;
139+
double f1Score;
140+
String modifiedPrompt = prompt;
141+
do {
142+
logger.debug("Iteration {}: RequestPrompt = {}", i, modifiedPrompt);
143+
f1Score = scorePrompt(sourceStore, targetStore, modifiedPrompt);
144+
logger.info("Iteration {}: F1-Score = {}", i, f1Score);
145+
f1Scores[i] = f1Score;
146+
modifiedPrompt = optimize(prompt);
147+
prompt = modifiedPrompt;
148+
i++;
149+
} while (i < MAXIMUM_ITERATIONS && f1Score < THRESHOLD_F1_SCORE);
150+
logger.info("Iterations {}: F1-Scores = {}", i, f1Scores);
151+
return prompt;
152+
}
153+
154+
/**
155+
* Optimizes the given prompt using the language model.
156+
* This method is used for a single iterative optimization step.
157+
*
158+
* @param prompt The original prompt to be optimized
159+
* @return The optimized prompt
160+
*/
161+
private String optimize(String prompt) {
162+
String request = optimizationPrompt.replace(ORIGINAL_PROMPT_KEY, prompt);
163+
164+
String key = KeyGenerator.generateKey(request);
165+
CacheKey cacheKey = new CacheKey(provider.modelName(), provider.seed(), CacheKey.Mode.CHAT, request, key);
166+
String response = cache.get(cacheKey, String.class);
167+
if (response == null) {
168+
logger.info("Optimizing ({}): {}", provider.modelName(), request);
169+
response = llm.chat(request);
170+
cache.put(cacheKey, response);
171+
}
172+
logger.debug("Response: {}", response);
173+
Pattern pattern = Pattern.compile(PROMPT_START + "(?s).*?" + PROMPT_END, Pattern.CASE_INSENSITIVE);
174+
Matcher matcher = pattern.matcher(response);
175+
if (matcher.find()) {
176+
response = matcher.group(0).strip();
177+
} else {
178+
logger.warn("No prompt found in response: {}", response);
179+
// Fallback to original prompt if no match found
180+
response = prompt;
181+
}
182+
return response;
183+
}
184+
185+
private double scorePrompt(ElementStore trainingStore, ElementStore targetStore, String prompt) {
186+
ModuleConfiguration classifierConfig = new ModuleConfiguration(
187+
"simple_" + provider.platform(), Map.of("model", provider.modelName(), PROMPT_TEMPLATE_KEY, prompt));
188+
Classifier classifier = Classifier.createClassifier(classifierConfig);
189+
List<ClassificationResult> results = classifier.classify(trainingStore, targetStore);
190+
191+
Set<TraceLink> traceLinks =
192+
aggregator.aggregate(trainingStore.getAllElements(), targetStore.getAllElements(), results);
193+
traceLinks = traceLinkIdPostProcessor.postprocess(traceLinks);
194+
List<String> traceLinkIds = trainingStore.getAllElements().stream()
195+
.map(Element::getIdentifier)
196+
.map(id -> id.substring(0, id.lastIndexOf(".")))
197+
.toList();
198+
Set<TraceLink> possibleTraceLinks = validTraceLinks.stream()
199+
.filter(tl -> traceLinkIds.contains(tl.sourceId()))
200+
.collect(Collectors.toSet());
201+
var classification = cmc.calculateMetrics(traceLinks, possibleTraceLinks, null);
202+
return classification.getF1();
203+
}
204+
205+
@Override
206+
protected AbstractPromptOptimizer copyOf(AbstractPromptOptimizer original) {
207+
return new IterativeOptimizer(threads, cache, provider, template, validTraceLinks);
208+
}
209+
}

0 commit comments

Comments
 (0)