Skip to content

Commit 7fbdd03

Browse files
committed
combine and simplify
1 parent cb2c52a commit 7fbdd03

File tree

8 files changed

+402
-157
lines changed

8 files changed

+402
-157
lines changed

snippets/ai/config.ts

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,35 @@
11
import { AiProvider } from './providers/ai-provider';
22
import { EventEmitter } from 'events';
33
import { z } from 'zod';
4-
4+
import chalk from 'chalk';
5+
import { inspect } from 'util';
56

67
const configSchema = z.object({
78
provider: z.enum(['docs', 'openai', 'mistral', 'atlas', 'ollama']),
89
model: z.string(),
10+
includeSampleDocs: z.boolean(),
11+
defaultCollection: z.string().optional(),
912
});
1013

1114
const configKeys = Object.keys(configSchema.shape) as Array<keyof ConfigSchema>;
1215

13-
1416
export type ConfigSchema = z.infer<typeof configSchema>;
1517
type ConfigKeys = keyof ConfigSchema;
1618

1719
const defaults: Record<ConfigKeys, any> = {
1820
provider: process.env.MONGOSH_AI_PROVIDER ?? 'docs',
1921
model: process.env.MONGOSH_AI_MODEL ?? 'default',
22+
includeSampleDocs: process.env.MONGOSH_AI_INCLUDE_SAMPLE_DOCS ?? true,
23+
defaultCollection: process.env.MONGOSH_AI_DEFAULT_COLLECTION,
2024
};
2125

2226
export class Config extends EventEmitter<{
23-
change: [{
24-
key: ConfigKeys;
25-
value: ConfigSchema[ConfigKeys];
26-
}];
27+
change: [
28+
{
29+
key: ConfigKeys;
30+
value: ConfigSchema[ConfigKeys];
31+
},
32+
];
2733
}> {
2834
private configMap: Record<string, any> = {};
2935

@@ -68,6 +74,16 @@ export class Config extends EventEmitter<{
6874
}
6975

7076
[Symbol.for('nodejs.util.inspect.custom')]() {
71-
return this.configMap;
77+
const lines = Object.entries(configSchema.shape).map(([key, schema]) => {
78+
let type: string | undefined = undefined;
79+
if (schema._def.typeName === 'ZodEnum') {
80+
type = `${schema._def.values.join(' | ')}`;
81+
}
82+
const i = (value: any) => inspect(value, {colors: true});
83+
84+
return ` ${i(key)}: ${chalk.white(i(this.configMap[key]))},${type ? chalk.gray(` // ${type}`) : ''}`;
85+
});
86+
87+
return `{\n${lines.join('\n')}\n}`;
7288
}
7389
}

snippets/ai/decorators.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,26 @@
1-
export function aiCommand<T extends Function>(
1+
export interface AiCommandOptions {
2+
requiresPrompt?: boolean;
3+
}
4+
5+
export function aiCommand({
6+
requiresPrompt = true,
7+
}: AiCommandOptions = {}) {
8+
return function decorator<T extends Function>(
29
value: T,
310
// eslint-disable-next-line @typescript-eslint/no-unused-vars
411
context: ClassMethodDecoratorContext
512
): T & { isDirectShellCommand: true } {
613
const wrappedFunction = function(this: any, ...args: any[]) {
14+
if (requiresPrompt === false && args.length > 0) {
15+
throw new Error('This command does not accept any arguments');
16+
} else if (requiresPrompt && args.length === 0) {
17+
throw new Error('Please specify a prompt to run');
18+
}
719
// Combine all arguments into a single string
820
const combinedString = args.join(' ');
921
// Call the original function with the combined string
1022
return value.call(this, combinedString);
1123
} as unknown as T; // Cast the wrapped function to match the original type
1224
return Object.assign(wrappedFunction, { isDirectShellCommand: true } as const);
1325
}
26+
}

snippets/ai/helpers.ts

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ export function wrapFunction(
6868
cliContext: CliContext,
6969
instance: any,
7070
name: string | undefined,
71-
fn: Function
71+
fn: Function,
7272
) {
7373
const wrapperFn = (...args: string[]) => {
7474
return Object.assign(fn(...args), {
@@ -88,11 +88,11 @@ export function wrapFunction(
8888
export function wrapAllFunctions(cliContext: CliContext, instance: any) {
8989
const instanceState = cliContext.db._mongo._instanceState;
9090
const methods = Object.getOwnPropertyNames(
91-
Object.getPrototypeOf(instance)
91+
Object.getPrototypeOf(instance),
9292
).filter((name) => {
9393
const descriptor = Object.getOwnPropertyDescriptor(
9494
Object.getPrototypeOf(instance),
95-
name
95+
name,
9696
);
9797
return (
9898
descriptor &&
@@ -119,16 +119,26 @@ interface HelpCommand {
119119
example?: string;
120120
}
121121

122-
export function formatHelpCommands(commands: HelpCommand[], provider: string, model: string): string {
123-
const maxCmdLength = Math.max(...commands.map(c => c.cmd.length));
124-
const formattedCommands = commands.map(c => {
125-
const padding = ' '.repeat(maxCmdLength - c.cmd.length);
126-
const base = ` ${chalk.yellow(c.cmd)}${padding} ${chalk.white(c.desc)}`;
127-
return c.example ? `${base} ${chalk.gray(`| ${c.example}`)}` : base;
128-
}).join('\n');
122+
export function formatHelpCommands(
123+
commands: HelpCommand[],
124+
{
125+
provider,
126+
model,
127+
collection,
128+
}: { provider: string; model: string; collection?: string },
129+
): string {
130+
const maxCmdLength = Math.max(...commands.map((c) => c.cmd.length));
131+
const formattedCommands = commands
132+
.map((c) => {
133+
const padding = ' '.repeat(maxCmdLength - c.cmd.length);
134+
const base = ` ${chalk.yellow(c.cmd)}${padding} ${chalk.white(c.desc)}`;
135+
return c.example ? `${base} ${chalk.gray(`| ${c.example}`)}` : base;
136+
})
137+
.join('\n');
129138

130139
return `${chalk.blue.bold('AI command suite for mongosh')}
131-
${chalk.gray(`Using ${chalk.white.bold(provider)} as provider and its ${chalk.white.bold(model)} model`)}\n
132-
${formattedCommands}
140+
${chalk.gray(`Collection: ${chalk.white.bold(collection ?? 'not set')}. Set it with ai.collection("collection_name")`)}\n
141+
${formattedCommands}\n
142+
${chalk.gray(`Using ${chalk.white.bold(provider)} as provider and its ${chalk.white.bold(model)} model`)}
133143
`.trim();
134144
}

snippets/ai/index.ts

Lines changed: 55 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@ import { aiCommand } from './decorators';
22
import { AiProvider, EmptyAiProvider } from './providers/ai-provider';
33
import { getAtlasAiProvider } from './providers/atlas/atlas-ai-provider';
44
import { getDocsAiProvider } from './providers/docs/docs-ai-provider';
5-
import {
6-
getAiSdkProvider,
7-
models,
8-
} from './providers/generic/ai-sdk-provider';
5+
import { getAiSdkProvider, models } from './providers/generic/ai-sdk-provider';
96
import { Config, ConfigSchema } from './config';
107
import { CliContext, wrapAllFunctions, formatHelpCommands } from './helpers';
8+
import chalk from 'chalk';
119

1210
class AI {
1311
private readonly replConfig: {
@@ -37,12 +35,13 @@ class AI {
3735
this.ai = this.getProvider(event.value as ConfigSchema['provider']);
3836
break;
3937
case 'model':
40-
if (Object.keys(models).includes(event.value)) {
38+
if (Object.keys(models).includes(event.value as string)) {
4139
this.ai = getAiSdkProvider(
4240
models[this.config.get('provider') as keyof typeof models](
43-
event.value,
41+
event.value as string,
4442
),
4543
this.cliContext,
44+
this.config,
4645
);
4746
} else {
4847
throw new Error(`Invalid model: ${event.value}`);
@@ -53,7 +52,9 @@ class AI {
5352
}
5453
});
5554

56-
this.ai = this.getProvider(process.env.MONGOSH_AI_PROVIDER as ConfigSchema['provider'] | undefined);
55+
this.ai = this.getProvider(
56+
process.env.MONGOSH_AI_PROVIDER as ConfigSchema['provider'] | undefined,
57+
);
5758
wrapAllFunctions(this.cliContext, this);
5859

5960
this.setupConfig();
@@ -65,66 +66,85 @@ class AI {
6566
this.ai = this.getProvider(this.config.get('provider'));
6667
}
6768

68-
private getProvider(provider: ConfigSchema['provider'] | undefined): AiProvider {
69+
private getProvider(
70+
provider: ConfigSchema['provider'] | undefined,
71+
): AiProvider {
6972
switch (provider) {
7073
case 'docs':
71-
return getDocsAiProvider(this.cliContext);
74+
return getDocsAiProvider(this.cliContext, this.config);
7275
case 'atlas':
73-
return getAtlasAiProvider(this.cliContext);
76+
return getAtlasAiProvider(this.cliContext, this.config);
7477
case 'openai':
7578
case 'mistral':
7679
case 'ollama':
7780
const model = this.config.get('model');
7881
return getAiSdkProvider(
7982
models[provider](model === 'default' ? undefined : model),
8083
this.cliContext,
84+
this.config,
8185
);
8286
default:
83-
return new EmptyAiProvider(this.cliContext);
87+
return new EmptyAiProvider(this.cliContext, this.config);
8488
}
8589
}
8690

87-
@aiCommand
88-
async command(prompt: string) {
89-
await this.ai.command(prompt);
91+
@aiCommand()
92+
async shell(prompt: string) {
93+
await this.ai.shell(prompt);
94+
}
95+
96+
@aiCommand()
97+
async data(prompt: string) {
98+
await this.ai.data(prompt);
9099
}
91100

92-
@aiCommand
101+
@aiCommand()
93102
async query(prompt: string) {
94103
await this.ai.query(prompt);
95104
}
96105

97-
@aiCommand
106+
@aiCommand()
98107
async ask(prompt: string) {
99108
await this.ai.ask(prompt);
100109
}
101110

102-
@aiCommand
111+
@aiCommand()
103112
async aggregate(prompt: string) {
104113
await this.ai.aggregate(prompt);
105114
}
106115

107-
@aiCommand
108-
async help(...args: string[]) {
109-
const commands = [
110-
{ cmd: 'ai.ask', desc: 'ask questions', example: 'ai.ask how do I run queries in mongosh?' },
111-
{ cmd: 'ai.command', desc: 'generate any mongosh command', example: 'ai.command create a new database' },
112-
{ cmd: 'ai.query', desc: 'generate a MongoDB query', example: 'ai.query find documents where name = "Ada"' },
113-
{ cmd: 'ai.aggregate', desc: 'generate a MongoDB aggregation', example: 'ai.aggregate find documents where name = "Ada"' },
114-
{ cmd: 'ai.config', desc: 'configure the AI commands', example: 'ai.config.set("provider", "ollama")' }
115-
];
116-
117-
this.ai.respond(
118-
formatHelpCommands(
119-
commands,
120-
this.config.get('provider'),
121-
this.config.get('model')
122-
)
123-
);
116+
@aiCommand({requiresPrompt: false})
117+
async help() {
118+
this.ai.help({
119+
provider: this.config.get('provider'),
120+
model: this.config.get('model'),
121+
});
122+
}
123+
124+
@aiCommand()
125+
async clear() {
126+
this.ai.clear();
127+
}
128+
129+
@aiCommand()
130+
async collection(name: string) {
131+
await this.ai.collection(name);
132+
}
133+
134+
@aiCommand()
135+
async provider(provider: string) {
136+
this.config.set('provider', provider);
137+
this.ai.respond(`Switched to ${chalk.blue(provider)} provider`);
138+
}
139+
140+
@aiCommand()
141+
async model(model: string) {
142+
this.config.set('model', model);
143+
this.ai.respond(`Switched to ${chalk.blue(model)} model`);
124144
}
125145

126146
[Symbol.for('nodejs.util.inspect.custom')]() {
127-
this.ai.help();
147+
this.help();
128148
return '';
129149
}
130150
}

0 commit comments

Comments
 (0)