Skip to content

Commit

Permalink
Updated langchain and fixed bug
Browse files Browse the repository at this point in the history
  • Loading branch information
lucagrippa committed Mar 1, 2024
1 parent dbdf8e8 commit e65d914
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 134 deletions.
2 changes: 1 addition & 1 deletion manifest.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"id": "ai-tagger",
"name": "AI Tagger",
"version": "1.0.1",
"version": "1.0.2",
"minAppVersion": "0.15.0",
"description": "Analyze and tag your document with one click for efficient note organization using AI. OpenAI API key required",
"author": "Luca Grippa",
Expand Down
44 changes: 22 additions & 22 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

97 changes: 62 additions & 35 deletions src/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { getAllTags, Notice } from 'obsidian';
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import { ChatOpenAI } from "@langchain/openai";
import { JsonOutputFunctionsParser } from "langchain/output_parsers";
import { JsonOutputKeyToolsParser, JsonOutputKeyToolsParserParams } from "@langchain/core/output_parsers/openai_tools";
import { Runnable } from '@langchain/core/runnables';
import {
ChatPromptTemplate,
Expand Down Expand Up @@ -35,11 +35,17 @@ export class LLM {

constructor(modelName: string, openAIApiKey: string) {
this.modelName = modelName;
const prompt = this.getPrompt();
const prompt: ChatPromptTemplate = this.getPrompt();
const functionCallingModel = this.getModel(modelName, openAIApiKey);

const outputParser = new JsonOutputFunctionsParser();
this.chain = prompt.pipe(functionCallingModel).pipe(outputParser);
// const outputParserParams = JsonOutputKeyToolsParserParams()
// const outputParser = new JsonOutputKeyToolsParser(keyName=document_tagger, returnSingle=true);
const outputParser = new JsonOutputKeyToolsParser({
keyName: "document_tagger",
returnSingle: true,
});

this.chain = prompt.pipe(functionCallingModel!).pipe(outputParser);
}

getPrompt() {
Expand All @@ -54,13 +60,17 @@ export class LLM {

const humanMessage = "DOCUMENT: \n ```{document}``` \n TAGS: \n"

const prompt = new ChatPromptTemplate({
promptMessages: [
SystemMessagePromptTemplate.fromTemplate(systemMessage),
HumanMessagePromptTemplate.fromTemplate(humanMessage),
],
inputVariables: ["tagsString", "document"],
});
const prompt = ChatPromptTemplate.fromMessages([
SystemMessagePromptTemplate.fromTemplate(systemMessage),
HumanMessagePromptTemplate.fromTemplate(humanMessage),
]);
// const prompt = new ChatPromptTemplate({
// promptMessages: [
// SystemMessagePromptTemplate.fromTemplate(systemMessage),
// HumanMessagePromptTemplate.fromTemplate(humanMessage),
// ],
// inputVariables: ["tagsString", "document"],
// });

return prompt;
}
Expand All @@ -70,26 +80,43 @@ export class LLM {
tags: z.array(z.string()).describe("An array of tags that best describes the text using existing tags."),
newTags: z.array(z.string()).describe("An array of tags that best describes the text using new tags."),
});
const llm = new ChatOpenAI({
temperature: 0,
modelName: modelName,
openAIApiKey: openAIApiKey,
});

// Binding "function_call" below makes the model always call the specified function.
// If you want to allow the model to call functions selectively, omit it.
const functionCallingModel = llm.bind({
functions: [
{
name: "document_tagger",
description: "Should always be used to tag documents.",
parameters: zodToJsonSchema(zodSchema),
try {
const llm = new ChatOpenAI({
temperature: 0,
modelName: modelName,
openAIApiKey: openAIApiKey,
});
// Binding "function_call" below makes the model always call the specified function.
// If you want to allow the model to call functions selectively, omit it.
const llmWithTools = llm.bind({
tools: [
{
type: "function" as const,
function: {
name: "document_tagger",
description: "Should always be used to tag documents.",
parameters: zodToJsonSchema(zodSchema),
},
}
],
tool_choice: {
type: "function" as const,
function: {
name: "document_tagger",
},
},
],
function_call: { name: "document_tagger" },
});
});

return llmWithTools;
} catch (error) {
if (error.message.includes('OpenAI or Azure OpenAI API key not found at new ChatOpenAI')) {
// Notify the user about the incorrect API key
throw new Error('Incorrect API key. Please check your API key.');
}
}


return functionCallingModel;
}

formatTagsString(tags: string[], newTags: string[]) {
Expand All @@ -98,7 +125,7 @@ export class LLM {
tags.forEach((tag: string) => {
tagsString += tag + " "
});

// if there are new tags, add a separator
if (newTags.length > 0) {
tagsString += "| "
Expand Down Expand Up @@ -166,12 +193,12 @@ export class LLM {
if (error.message.includes('Incorrect API key')) {
// Notify the user about the incorrect API key
throw new Error('Incorrect API key. Please check your API key.');
// } else if (error.message.includes('Invalid Authentication')) {
// // Notify the user about the incorrect API key
// throw new Error('Incorrect API key. Please check your API key.');
// } else if (error.message.includes('You must be a member of an organization to use the API')) {
// // Notify the user about the incorrect API key
// throw new Error('Your account is not part of an organization. Contact OpenAI to get added to a new organization or ask your organization manager to invite you to an organization.');
// } else if (error.message.includes('Invalid Authentication')) {
// // Notify the user about the incorrect API key
// throw new Error('Incorrect API key. Please check your API key.');
// } else if (error.message.includes('You must be a member of an organization to use the API')) {
// // Notify the user about the incorrect API key
// throw new Error('Your account is not part of an organization. Contact OpenAI to get added to a new organization or ask your organization manager to invite you to an organization.');
} else if (error.message.includes('Rate limit reached for requests')) {
// Notify the user about the incorrect API key
throw new Error('You are sending requests too quickly. Please pace your requests or read OpenAI\'s Rate limit guide.');
Expand Down
Loading

0 comments on commit e65d914

Please sign in to comment.