Skip to content

Commit cb2c52a

Browse files
committed
add ollama and simplify
1 parent 2e05625 commit cb2c52a

File tree

12 files changed

+766
-290
lines changed

12 files changed

+766
-290
lines changed

snippets/ai/config.ts

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import { AiProvider } from './providers/ai-provider';
2+
import { EventEmitter } from 'events';
3+
import { z } from 'zod';
4+
5+
6+
const configSchema = z.object({
7+
provider: z.enum(['docs', 'openai', 'mistral', 'atlas', 'ollama']),
8+
model: z.string(),
9+
});
10+
11+
const configKeys = Object.keys(configSchema.shape) as Array<keyof ConfigSchema>;
12+
13+
14+
export type ConfigSchema = z.infer<typeof configSchema>;
15+
type ConfigKeys = keyof ConfigSchema;
16+
17+
const defaults: Record<ConfigKeys, any> = {
18+
provider: process.env.MONGOSH_AI_PROVIDER ?? 'docs',
19+
model: process.env.MONGOSH_AI_MODEL ?? 'default',
20+
};
21+
22+
export class Config extends EventEmitter<{
23+
change: [{
24+
key: ConfigKeys;
25+
value: ConfigSchema[ConfigKeys];
26+
}];
27+
}> {
28+
private configMap: Record<string, any> = {};
29+
30+
constructor(
31+
private readonly replConfig: {
32+
set: (key: string, value: any) => Promise<void>;
33+
get: <T>(key: string) => Promise<T>;
34+
},
35+
) {
36+
super();
37+
}
38+
39+
async setup(): Promise<void> {
40+
const keys = Object.keys(configSchema.shape) as Array<keyof ConfigSchema>;
41+
for (const key of keys) {
42+
this.configMap[key] = (await this.replConfig.get(key)) ?? defaults[key];
43+
}
44+
}
45+
46+
get<K extends keyof ConfigSchema>(key: K): ConfigSchema[K] {
47+
this.assertKey(key);
48+
return this.configMap[key];
49+
}
50+
51+
assertKey(key: string): asserts key is ConfigKeys {
52+
if (!configKeys.includes(key as ConfigKeys)) {
53+
throw new Error(
54+
`Invalid config key: ${key}. Valid keys are: ${configKeys.join(', ')}.`,
55+
);
56+
}
57+
}
58+
59+
async set(key: ConfigKeys, value: any): Promise<void> {
60+
this.assertKey(key);
61+
62+
// Validate the value based on the key
63+
value = configSchema.shape[key].parse(value);
64+
65+
await this.replConfig.set(key, value);
66+
this.configMap[key] = value;
67+
this.emit('change', { key, value });
68+
}
69+
70+
[Symbol.for('nodejs.util.inspect.custom')]() {
71+
return this.configMap;
72+
}
73+
}

snippets/ai/helpers.ts

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,88 @@ export class LoadingAnimation {
4747
}
4848
}
4949
}
50+
51+
export interface CliContext {
52+
db: {
53+
_mongo: {
54+
_instanceState: {
55+
evaluationListener: {
56+
setConfig: (key: string, value: any) => Promise<void>;
57+
getConfig: <T>(key: string) => Promise<T>;
58+
};
59+
registerPlugin: (plugin: any) => void;
60+
shellApi: Record<string, any>;
61+
context: Record<string, any>;
62+
};
63+
};
64+
};
65+
}
66+
67+
export function wrapFunction(
68+
cliContext: CliContext,
69+
instance: any,
70+
name: string | undefined,
71+
fn: Function
72+
) {
73+
const wrapperFn = (...args: string[]) => {
74+
return Object.assign(fn(...args), {
75+
[Symbol.for('@@mongosh.syntheticPromise')]: true,
76+
});
77+
};
78+
wrapperFn.isDirectShellCommand = true;
79+
wrapperFn.returnsPromise = true;
80+
81+
const instanceState = cliContext.db._mongo._instanceState;
82+
83+
instanceState.shellApi[name ? `ai.${name}` : 'ai'] = instanceState.context[
84+
name ? `ai.${name}` : 'ai'
85+
] = wrapperFn;
86+
}
87+
88+
export function wrapAllFunctions(cliContext: CliContext, instance: any) {
89+
const instanceState = cliContext.db._mongo._instanceState;
90+
const methods = Object.getOwnPropertyNames(
91+
Object.getPrototypeOf(instance)
92+
).filter((name) => {
93+
const descriptor = Object.getOwnPropertyDescriptor(
94+
Object.getPrototypeOf(instance),
95+
name
96+
);
97+
return (
98+
descriptor &&
99+
typeof descriptor.value === 'function' &&
100+
name !== 'constructor'
101+
);
102+
});
103+
104+
// for all methods, wrap them with the wrapFunction method
105+
for (const methodName of methods) {
106+
const method = instance[methodName];
107+
if (typeof method === 'function' && method.isDirectShellCommand) {
108+
wrapFunction(cliContext, instance, methodName, method.bind(instance));
109+
}
110+
}
111+
instanceState.registerPlugin(instance);
112+
113+
wrapFunction(cliContext, instance, undefined, instance.help.bind(instance));
114+
}
115+
116+
interface HelpCommand {
117+
cmd: string;
118+
desc: string;
119+
example?: string;
120+
}
121+
122+
export function formatHelpCommands(commands: HelpCommand[], provider: string, model: string): string {
123+
const maxCmdLength = Math.max(...commands.map(c => c.cmd.length));
124+
const formattedCommands = commands.map(c => {
125+
const padding = ' '.repeat(maxCmdLength - c.cmd.length);
126+
const base = ` ${chalk.yellow(c.cmd)}${padding} ${chalk.white(c.desc)}`;
127+
return c.example ? `${base} ${chalk.gray(`| ${c.example}`)}` : base;
128+
}).join('\n');
129+
130+
return `${chalk.blue.bold('AI command suite for mongosh')}
131+
${chalk.gray(`Using ${chalk.white.bold(provider)} as provider and its ${chalk.white.bold(model)} model`)}\n
132+
${formattedCommands}
133+
`.trim();
134+
}

snippets/ai/index.ts

Lines changed: 104 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,134 @@
11
import { aiCommand } from './decorators';
2-
import { AiProvider } from './providers/ai-provider';
2+
import { AiProvider, EmptyAiProvider } from './providers/ai-provider';
3+
import { getAtlasAiProvider } from './providers/atlas/atlas-ai-provider';
34
import { getDocsAiProvider } from './providers/docs/docs-ai-provider';
5+
import {
6+
getAiSdkProvider,
7+
models,
8+
} from './providers/generic/ai-sdk-provider';
9+
import { Config, ConfigSchema } from './config';
10+
import { CliContext, wrapAllFunctions, formatHelpCommands } from './helpers';
411

512
class AI {
6-
constructor(
7-
private readonly cliContext: any,
8-
private readonly ai: AiProvider,
9-
) {
10-
const methods = Object.getOwnPropertyNames(
11-
Object.getPrototypeOf(this),
12-
).filter((name) => {
13-
const descriptor = Object.getOwnPropertyDescriptor(
14-
Object.getPrototypeOf(this),
15-
name,
16-
);
17-
return (
18-
descriptor &&
19-
typeof descriptor.value === 'function' &&
20-
name !== 'constructor'
21-
);
22-
});
13+
private readonly replConfig: {
14+
set: (key: string, value: any) => Promise<void>;
15+
get: <T>(key: string) => Promise<T>;
16+
};
2317

24-
// for all methods, wrap them with the wrapFunction method
25-
for (const methodName of methods) {
26-
const method = (this as any)[methodName];
27-
if (typeof method === 'function' && method.isDirectShellCommand) {
28-
this.wrapFunction(methodName, method.bind(this));
29-
}
30-
}
18+
private ai: AiProvider;
19+
public config: Config;
20+
21+
constructor(private readonly cliContext: CliContext) {
3122
const instanceState = this.cliContext.db._mongo._instanceState;
32-
instanceState.registerPlugin(this);
3323

34-
this.wrapFunction(undefined, this.help.bind(this));
24+
this.replConfig = {
25+
set: (key, value) =>
26+
instanceState.evaluationListener.setConfig(`snippet_ai_${key}`, value),
27+
get: (key) =>
28+
instanceState.evaluationListener.getConfig(`snippet_ai_${key}`),
29+
};
30+
31+
this.config = new Config(this.replConfig);
32+
33+
// Set up provider change listener
34+
this.config.on('change', (event) => {
35+
switch (event.key) {
36+
case 'provider':
37+
this.ai = this.getProvider(event.value as ConfigSchema['provider']);
38+
break;
39+
case 'model':
40+
if (Object.keys(models).includes(event.value)) {
41+
this.ai = getAiSdkProvider(
42+
models[this.config.get('provider') as keyof typeof models](
43+
event.value,
44+
),
45+
this.cliContext,
46+
);
47+
} else {
48+
throw new Error(`Invalid model: ${event.value}`);
49+
}
50+
break;
51+
default:
52+
break;
53+
}
54+
});
55+
56+
this.ai = this.getProvider(process.env.MONGOSH_AI_PROVIDER as ConfigSchema['provider'] | undefined);
57+
wrapAllFunctions(this.cliContext, this);
58+
59+
this.setupConfig();
3560
}
3661

37-
private wrapFunction(name: string | undefined, fn: Function) {
38-
const wrapperFn = (...args: string[]) => {
39-
return Object.assign(fn(...args), {
40-
[Symbol.for('@@mongosh.syntheticPromise')]: true,
41-
});
42-
};
43-
wrapperFn.isDirectShellCommand = true;
44-
wrapperFn.returnsPromise = true;
62+
async setupConfig() {
63+
await this.config.setup();
4564

46-
const instanceState = this.cliContext.db._mongo._instanceState;
65+
this.ai = this.getProvider(this.config.get('provider'));
66+
}
67+
68+
private getProvider(provider: ConfigSchema['provider'] | undefined): AiProvider {
69+
switch (provider) {
70+
case 'docs':
71+
return getDocsAiProvider(this.cliContext);
72+
case 'atlas':
73+
return getAtlasAiProvider(this.cliContext);
74+
case 'openai':
75+
case 'mistral':
76+
case 'ollama':
77+
const model = this.config.get('model');
78+
return getAiSdkProvider(
79+
models[provider](model === 'default' ? undefined : model),
80+
this.cliContext,
81+
);
82+
default:
83+
return new EmptyAiProvider(this.cliContext);
84+
}
85+
}
4786

48-
instanceState.shellApi[name ? `ai.${name}` : 'ai'] = instanceState.context[
49-
name ? `ai.${name}` : 'ai'
50-
] = wrapperFn;
87+
@aiCommand
88+
async command(prompt: string) {
89+
await this.ai.command(prompt);
5190
}
5291

5392
@aiCommand
54-
async query(code: string) {
55-
return await this.ai.query(code);
93+
async query(prompt: string) {
94+
await this.ai.query(prompt);
5695
}
5796

5897
@aiCommand
59-
async ask(code: string) {
60-
return await this.ai.ask(code);
98+
async ask(prompt: string) {
99+
await this.ai.ask(prompt);
61100
}
62101

63102
@aiCommand
64-
async aggregate(code: string) {
65-
return await this.ai.aggregate(code);
103+
async aggregate(prompt: string) {
104+
await this.ai.aggregate(prompt);
66105
}
67106

68107
@aiCommand
69108
async help(...args: string[]) {
109+
const commands = [
110+
{ cmd: 'ai.ask', desc: 'ask questions', example: 'ai.ask how do I run queries in mongosh?' },
111+
{ cmd: 'ai.command', desc: 'generate any mongosh command', example: 'ai.command create a new database' },
112+
{ cmd: 'ai.query', desc: 'generate a MongoDB query', example: 'ai.query find documents where name = "Ada"' },
113+
{ cmd: 'ai.aggregate', desc: 'generate a MongoDB aggregation', example: 'ai.aggregate find documents where name = "Ada"' },
114+
{ cmd: 'ai.config', desc: 'configure the AI commands', example: 'ai.config.set("provider", "ollama")' }
115+
];
116+
117+
this.ai.respond(
118+
formatHelpCommands(
119+
commands,
120+
this.config.get('provider'),
121+
this.config.get('model')
122+
)
123+
);
124+
}
125+
126+
[Symbol.for('nodejs.util.inspect.custom')]() {
70127
this.ai.help();
128+
return '';
71129
}
72130
}
73131

74132
module.exports = (globalThis: any) => {
75-
globalThis.ai = new AI(globalThis, getDocsAiProvider(globalThis));
133+
globalThis.ai = new AI(globalThis);
76134
};

snippets/ai/logger.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import EventEmitter from "events";
22

3-
const IS_DEBUG = true;
3+
const IS_DEBUG = process.env.DEBUG === 'true';
44

55
class Logger extends EventEmitter {
66
debug(...args: unknown[]) {

0 commit comments

Comments
 (0)