Skip to content

Commit 3d0ecb3

Browse files
committed
MAB_runtime_updates
1 parent 4da1f00 commit 3d0ecb3

19 files changed

Lines changed: 131395 additions & 222 deletions

File tree

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"arms": [
3+
{
4+
"name": "original",
5+
"weight": 0.22628347156945203,
6+
"matrix": "original_devdata.json"
7+
},
8+
{
9+
"name": "filtered",
10+
"weight": 0.2710968140239559,
11+
12+
"matrix": "filtered_devdata.json"
13+
},
14+
{
15+
"name": "reordered",
16+
"weight": 0.2602214321074131,
17+
"matrix": "reordered_devdata.json"
18+
},
19+
{
20+
"name": "reordered_filtered",
21+
"weight": 0.24239828229917879,
22+
"matrix": "reordered_filtered_devdata.json"
23+
}
24+
]
25+
}

app/src/main/assets/filtered_devdata.json

Lines changed: 32265 additions & 0 deletions
Large diffs are not rendered by default.

app/src/main/assets/original_devdata.json

Lines changed: 33195 additions & 0 deletions
Large diffs are not rendered by default.

app/src/main/assets/reordered_devdata.json

Lines changed: 33195 additions & 0 deletions
Large diffs are not rendered by default.

app/src/main/assets/reordered_filtered_devdata.json

Lines changed: 32250 additions & 0 deletions
Large diffs are not rendered by default.

app/src/main/assets/tutors/activity_selector/arm-weights.json

Lines changed: 0 additions & 39 deletions
This file was deleted.

app/src/main/java/cmu/xprize/robotutor/RoboTutor.java

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
import java.util.Date;
6060
import java.util.Locale;
6161
import java.util.Objects;
62+
import java.util.List;
63+
6264

6365
import cmu.xprize.comp_intervention.data.CInterventionStudentData;
6466
import cmu.xprize.comp_intervention.CInterventionTimes;
@@ -96,7 +98,7 @@
9698
import cmu.xprize.util.TCONST;
9799
import cmu.xprize.util.TTSsynthesizer;
98100
import edu.cmu.xprize.listener.ListenerBase;
99-
101+
import cmu.xprize.robotutor.tutorengine.util.Arm;
100102
import static cmu.xprize.comp_logging.PerformanceLogItem.MATRIX_TYPE.LITERACY_MATRIX;
101103
import static cmu.xprize.comp_logging.PerformanceLogItem.MATRIX_TYPE.MATH_MATRIX;
102104
import static cmu.xprize.comp_logging.PerformanceLogItem.MATRIX_TYPE.SONGS_MATRIX;
@@ -137,7 +139,7 @@ public class RoboTutor extends Activity implements IReadyListener, IRoboTutor, H
137139
private static final boolean QUICK_DEBUG_CONFIG = false;
138140
private static final ConfigurationItems QUICK_DEBUG_CONFIG_OPTION = ConfigurationQuickOptions.DEBUG_EN;
139141

140-
public static final String MATRIX_FILE = "dev_data.open.json";
142+
public static final String MATRIX_FILE = "dev_data.open.json";
141143
public static final String ARM_WEIGHTS_FILE = "arm-weights.json";
142144

143145
private static final String LOG_SEQUENCE_ID = "LOG_SEQUENCE_ID";
@@ -211,6 +213,8 @@ public class RoboTutor extends Activity implements IReadyListener, IRoboTutor, H
211213

212214
//Declare armName
213215
private String armName = "default_arm";
216+
private float armWeight = 0;
217+
public static String sArmName = null;
214218

215219
@Override
216220
protected void onCreate(Bundle savedInstanceState) {
@@ -356,35 +360,53 @@ private void initializeAndStartLogs() {
356360
String initTime = new SimpleDateFormat("yyyy.MM.dd.HH.mm.ss", Locale.US).format(calendar.getTime());
357361
SEQUENCE_ID_STRING = String.format(Locale.US, "%06d", getNextLogSequenceId());
358362
// NOTE: Need to include the configuration name when that is fully merged
359-
String logFilename = "RoboTutor_" + armName +
360-
Configuration.configVersion(this) + "_" + BuildConfig.VERSION_NAME + "_" + SEQUENCE_ID_STRING +
361-
"_" + initTime + "_" + Build.SERIAL;
362-
363+
String logFilename = "RoboTutor_" + Build.SERIAL + "_" + SEQUENCE_ID_STRING + "_" +
364+
Configuration.configVersion(this) + "_" + BuildConfig.VERSION_NAME + "_" + armName + "_" + armWeight +
365+
"_" + initTime;
363366
Log.w("LOG_DEBUG", "Beginning new session with LOG_FILENAME = " + logFilename);
364367

365368
logManager = CLogManager.getInstance();
366369
logManager.transferHotLogs(hotLogPath, readyLogPath);
367370
logManager.transferHotLogs(hotLogPathPerf, readyLogPathPerf);
368371

369-
logManager.startLogging(hotLogPath, logFilename);
372+
logManager.startLoggingWithDynamicFilename(hotLogPath, logFilename);
370373
CErrorManager.setLogManager(logManager);
371374

372375
perfLogManager = CPerfLogManager.getInstance();
373-
perfLogManager.startLogging(hotLogPathPerf, "PERF_" + logFilename);
374-
375-
CInterventionLogManager.getInstance().startLogging(interventionLogPath,
376+
perfLogManager.startLoggingWithDynamicFilename(hotLogPathPerf, "PERF_" + logFilename);
377+
CInterventionLogManager.getInstance().startLoggingWithDynamicFilename(interventionLogPath,
376378
"INT_" + logFilename);
377379

380+
ConfigurationItems config = new ConfigurationItems();
381+
// Use MAB
382+
if (config.use_MAB) {
383+
armName = MABHandler.getArm(ARM_WEIGHTS_FILE, null);
384+
List<Arm> arms = MABHandler.getarms(ARM_WEIGHTS_FILE, null);
385+
armWeight = MABHandler.getArmWeight(armName, arms);
386+
String matrixName = MABHandler.getMatrixName(armName, arms);
387+
// Log the arm details using Log.w for visibility
388+
Log.w(TAG, "Selected Arm: " + armName + ", Arm Weight: " + armWeight + ", Matrix Name: " + matrixName);
389+
} else {
390+
armName = "default_arm"; // Set back to default
391+
armWeight = 0;
392+
}
393+
394+
// Rename the Log File in-place
395+
String newLogFilename = "RoboTutor_" + Build.SERIAL + "_" + SEQUENCE_ID_STRING + "_" +
396+
Configuration.configVersion(this) + "_" + BuildConfig.VERSION_NAME + "_" + armName + "_" + armWeight +
397+
"_" + initTime;
398+
399+
logManager.updateLogFilename(newLogFilename);
400+
perfLogManager.updateLogFilename("PERF_" + newLogFilename);
401+
CInterventionLogManager.getInstance().updateLogFilename("INT_" + newLogFilename);
402+
403+
404+
405+
378406
// TODO : implement time stamps
379407
logManager.postDateTimeStamp(GRAPH_MSG, "RoboTutor:SessionStart");
380408
logManager.postEvent_I(GRAPH_MSG, "EngineVersion:" + VERSION_RT);
381409

382-
// After starting logging, select the arm name using MABHandler
383-
armName = MABHandler.getArm(ARM_WEIGHTS_FILE, null);
384-
385-
// Update the log filename with the selected arm name
386-
logFilename = logFilename.replace("default_arm", armName);
387-
Log.w(TAG, "Log filename updated: " + logFilename);
388410
}
389411

390412
/**
@@ -942,9 +964,14 @@ protected void onStart() {
942964
// Start the async task to initialize the tutor
943965
//
944966
new tutorConfigTask().execute();
945-
}
946967

968+
// Add ArmName on banner
969+
sArmName = armName;
970+
}
947971

972+
public static String getArmName() {
973+
return sArmName;
974+
}
948975
/**
949976
* requery DB Cursors here
950977
*/

app/src/main/java/cmu/xprize/robotutor/tutorengine/util/MABHandler.java

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import cmu.xprize.robotutor.tutorengine.graph.vars.IScope2;
1313
import cmu.xprize.util.IScope;
1414
import cmu.xprize.util.JSON_Helper;
15+
import cmu.xprize.comp_logging.PerformanceLogItem;
1516

1617
/**
1718
* Handler for MAB (Multi-Arm Bandit)
@@ -32,7 +33,7 @@ public class MABHandler {
3233

3334
public static String getArm(String dataSource, IScope2 scope) {
3435
List<Arm> arms = getarms(dataSource, scope);
35-
Arm selectedArm = selectArm(arms);
36+
Arm selectedArm = selectArm(arms, 0.1F);
3637
// Ensure that selected arm is not null
3738
if (selectedArm != null) {
3839
Log.d(TAG, "getArm: selected = " + selectedArm.name);
@@ -43,52 +44,73 @@ public static String getArm(String dataSource, IScope2 scope) {
4344
}
4445
}
4546

46-
// Selects an arm from a list of arms
47-
private static Arm selectArm(List<Arm> arms) {
48-
float sum = 0;
47+
public static Float getArmWeight(String armName, List<Arm> arms) {
4948
for (Arm arm : arms) {
50-
sum += arm.weight;
49+
if (arm.name.equals(armName)) {
50+
return arm.weight;
51+
}
5152
}
53+
return null;
54+
}
5255

53-
// Select random number between 0 and sum
54-
float p = getRandom(0, sum);
55-
56-
// find out where p lies
57-
float bottom = 0;
56+
public static String getMatrixName(String armName, List<Arm> arms) {
5857
for (Arm arm : arms) {
59-
float top = bottom + arm.weight;
60-
if (bottom <= p && p <= top) {
61-
return arm;
58+
if (arm.name.equals(armName)) {
59+
return arm.matrix;
6260
}
63-
bottom = top;
6461
}
6562
return null;
6663
}
6764

65+
66+
// Selects an arm from a list of arms using ε-greedy algorithm
67+
private static Arm selectArm(List<Arm> arms, float epsilon) {
68+
// Random number to decide between exploration and exploitation
69+
float randomV = getRandom(0, 1);
70+
if (randomV < epsilon) {
71+
// Exploration:choose a random arm
72+
int randomIndex = (int) getRandom(0, arms.size());
73+
Arm randomArm = arms.get(randomIndex);
74+
return randomArm;
75+
} else {
76+
// Exploitation:choose the arm with highest weight
77+
Arm bestArm = null;
78+
float maxWeight = Float.NEGATIVE_INFINITY;
79+
for (Arm arm : arms) {
80+
if (arm.weight > maxWeight) {
81+
maxWeight = arm.weight;
82+
bestArm = arm;
83+
}
84+
}
85+
return bestArm;
86+
}
87+
}
88+
89+
6890
// Returns random number between [min, max]
6991
private static float getRandom(float min, float max) {
7092
return (float) (min + Math.random() * (max - min));
7193
}
7294

7395

74-
private static List<Arm> getarms(String dataSource, IScope2 scope) {
96+
public static List<Arm> getarms(String dataSource, IScope2 scope) {
7597
String jsonData = JSON_Helper.cacheData(dataSource);
7698
List<Arm> arms = new ArrayList<>();
7799
try {
78100
JSONObject rootObject = new JSONObject(jsonData);
79101
JSONArray rootArray = rootObject.getJSONArray(KEY_ARRAY);
80102
arms = parseArray(rootArray, scope);
81-
103+
82104
// Adding logging to print arms
83105
for (Arm arm : arms) {
84-
Log.d(TAG, "Arm: " + arm.name + ", Weight: " + arm.weight + ", Matrix Path" + arm.matrix);
106+
Log.d(TAG, "Arm: " + arm.name + ", Weight: " + arm.weight + ", Matrix Path" + arm.matrix);
85107
}
86-
108+
87109
} catch (Exception e) {
88110
Log.e(TAG, "Error in getarms: " + e.getMessage());
89111
}
90112
return arms;
91-
}
113+
}
92114

93115

94116
private static List<Arm> parseArray(JSONArray array, IScope2 scope) throws JSONException {
@@ -102,4 +124,4 @@ private static List<Arm> parseArray(JSONArray array, IScope2 scope) throws JSONE
102124
}
103125

104126

105-
}
127+
}

app/src/main/java/cmu/xprize/robotutor/tutorengine/widgets/core/TBanner.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,12 @@ public void setVersionID(String versionID) {
164164

165165
//mTutor_Ver += versionID; // not sure why this versionID was used, as it seems to have no significance
166166
mTutor_Ver += "v" + Configuration.configVersion(getContext());
167-
167+
// Retrieve the arm name from RoboTutor
168+
String armName = RoboTutor.getArmName();
169+
if (armName != null) {
170+
// Add a new line, a space, or any formatting you like
171+
mTutor_Ver += "\n" + armName;
172+
}
168173
mVersion.setText(mTutor_Ver);
169174
}
170175
else {

comp_logging/src/main/java/cmu/xprize/comp_logging/CLogManager.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ private CLogManager() {
3535
}
3636

3737

38-
38+
@Override
39+
public void startLoggingWithDynamicFilename(String logPath, String logFilename) {
40+
super.startLoggingWithDynamicFilename(logPath, logFilename);
41+
}
42+
@Override
43+
public void updateLogFilename(String newFilename) {
44+
super.updateLogFilename(newFilename);
45+
}
3946

4047
}

0 commit comments

Comments
 (0)