diff --git a/javascript-sdk/src/sdk.ts b/javascript-sdk/src/sdk.ts index 9de9b47..f439bda 100644 --- a/javascript-sdk/src/sdk.ts +++ b/javascript-sdk/src/sdk.ts @@ -22,36 +22,51 @@ function generateCanaryWord(length = 8): string { } export default class RebuffSdk implements Rebuff { - private vectorStore: VectorStore; - private strategies: Record; + private sdkConfig: SdkConfig; + private vectorStore: VectorStore | undefined; + private strategies: Record | undefined; private defaultStrategy: string; - private constructor(strategies: Record, vectorStore: VectorStore) { - this.vectorStore = vectorStore; - this.strategies = strategies; + /** + * @deprecated Use `RebuffSdk.init` instead. + */ + constructor(config: SdkConfig) { + // We're keeping this constructor for backwards compatibility. In the future, we can make it private and + // simplify this class quite a bit. + + this.sdkConfig = config; this.defaultStrategy = "standard"; } public static async init(config: SdkConfig): Promise { - const vectorStore = await initVectorStore(config); + const sdk = new RebuffSdk(config); + sdk.vectorStore = await initVectorStore(config); + sdk.strategies = await sdk.getStrategies(); + return sdk; + } + + private async getStrategies(): Promise> { + if (this.strategies) { + return this.strategies; + } const openai = { - conn: getOpenAIInstance(config.openai.apikey), - model: config.openai.model || "gpt-3.5-turbo", + conn: getOpenAIInstance(this.sdkConfig.openai.apikey), + model: this.sdkConfig.openai.model || "gpt-3.5-turbo", }; const heuristicScoreThreshold = 0.75; const vectorScoreThreshold = 0.9; const openaiScoreThreshold = 0.9; - const strategies = { + this.strategies = { // For now, this is the only strategy. "standard": { tactics: [ new Heuristic(heuristicScoreThreshold), - new Vector(vectorScoreThreshold, vectorStore), + new Vector(vectorScoreThreshold, await this.getVectorStore()), new OpenAI(openaiScoreThreshold, openai.model, openai.conn), ] }, }; - return new RebuffSdk(strategies, vectorStore); + return this.strategies; } async detectInjection({ @@ -69,9 +84,10 @@ export default class RebuffSdk implements Rebuff { throw new RebuffError("userInput is required"); } + const strategies = await this.getStrategies(); let injectionDetected = false; let tacticResults: TacticResult[] = []; - for (const tactic of this.strategies[this.defaultStrategy].tactics) { + for (const tactic of strategies[this.defaultStrategy].tactics) { const tacticOverride = tacticOverrides.find(t => t.name === tactic.name); if (tacticOverride && tacticOverride.run === false) { continue; @@ -124,11 +140,19 @@ export default class RebuffSdk implements Rebuff { return false; } + async getVectorStore(): Promise { + if (this.vectorStore) { + return this.vectorStore; + } + this.vectorStore = await initVectorStore(this.sdkConfig); + return this.vectorStore + } + async logLeakage( input: string, metaData: Record ): Promise { - await this.vectorStore.addDocuments([new Document({ + await (await this.getVectorStore()).addDocuments([new Document({ metadata: metaData, pageContent: input, })]);