chore: Update version to 1.1.9 and add Model Settings to Ollama settings page

This commit is contained in:
n4ze3m
2024-05-23 00:39:44 +05:30
parent d2afcc6a39
commit b3a455382c
13 changed files with 1271 additions and 18 deletions

406
src/models/ChatOllama.ts Normal file
View File

@@ -0,0 +1,406 @@
import type { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
import {
SimpleChatModel,
type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
AIMessageChunk,
BaseMessage,
ChatMessage,
} from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";
import type { StringWithAutocomplete } from "@langchain/core/utils/types";
import {
createOllamaChatStream,
createOllamaGenerateStream,
type OllamaInput,
type OllamaMessage,
} from "./utils/ollama";
export interface ChatOllamaInput extends OllamaInput { }
export interface ChatOllamaCallOptions extends BaseLanguageModelCallOptions { }
export class ChatOllama
extends SimpleChatModel<ChatOllamaCallOptions>
implements ChatOllamaInput {
static lc_name() {
return "ChatOllama";
}
lc_serializable = true;
model = "llama2";
baseUrl = "http://localhost:11434";
keepAlive = "5m";
embeddingOnly?: boolean;
f16KV?: boolean;
frequencyPenalty?: number;
headers?: Record<string, string>;
logitsAll?: boolean;
lowVram?: boolean;
mainGpu?: number;
mirostat?: number;
mirostatEta?: number;
mirostatTau?: number;
numBatch?: number;
numCtx?: number;
numGpu?: number;
numGqa?: number;
numKeep?: number;
numPredict?: number;
numThread?: number;
penalizeNewline?: boolean;
presencePenalty?: number;
repeatLastN?: number;
repeatPenalty?: number;
ropeFrequencyBase?: number;
ropeFrequencyScale?: number;
temperature?: number;
stop?: string[];
tfsZ?: number;
topK?: number;
topP?: number;
typicalP?: number;
useMLock?: boolean;
useMMap?: boolean;
vocabOnly?: boolean;
seed?: number;
format?: StringWithAutocomplete<"json">;
constructor(fields: OllamaInput & BaseChatModelParams) {
super(fields);
this.model = fields.model ?? this.model;
this.baseUrl = fields.baseUrl?.endsWith("/")
? fields.baseUrl.slice(0, -1)
: fields.baseUrl ?? this.baseUrl;
this.keepAlive = fields.keepAlive ?? this.keepAlive;
this.embeddingOnly = fields.embeddingOnly;
this.f16KV = fields.f16KV;
this.frequencyPenalty = fields.frequencyPenalty;
this.headers = fields.headers;
this.logitsAll = fields.logitsAll;
this.lowVram = fields.lowVram;
this.mainGpu = fields.mainGpu;
this.mirostat = fields.mirostat;
this.mirostatEta = fields.mirostatEta;
this.mirostatTau = fields.mirostatTau;
this.numBatch = fields.numBatch;
this.numCtx = fields.numCtx;
this.numGpu = fields.numGpu;
this.numGqa = fields.numGqa;
this.numKeep = fields.numKeep;
this.numPredict = fields.numPredict;
this.numThread = fields.numThread;
this.penalizeNewline = fields.penalizeNewline;
this.presencePenalty = fields.presencePenalty;
this.repeatLastN = fields.repeatLastN;
this.repeatPenalty = fields.repeatPenalty;
this.ropeFrequencyBase = fields.ropeFrequencyBase;
this.ropeFrequencyScale = fields.ropeFrequencyScale;
this.temperature = fields.temperature;
this.stop = fields.stop;
this.tfsZ = fields.tfsZ;
this.topK = fields.topK;
this.topP = fields.topP;
this.typicalP = fields.typicalP;
this.useMLock = fields.useMLock;
this.useMMap = fields.useMMap;
this.vocabOnly = fields.vocabOnly;
this.format = fields.format;
this.seed = fields.seed;
}
protected getLsParams(options: this["ParsedCallOptions"]) {
const params = this.invocationParams(options);
return {
ls_provider: "ollama",
ls_model_name: this.model,
ls_model_type: "chat",
ls_temperature: this.temperature ?? undefined,
ls_stop: this.stop,
ls_max_tokens: params.options.num_predict,
};
}
_llmType() {
return "ollama";
}
/**
* A method that returns the parameters for an Ollama API call. It
* includes model and options parameters.
* @param options Optional parsed call options.
* @returns An object containing the parameters for an Ollama API call.
*/
invocationParams(options?: this["ParsedCallOptions"]) {
return {
model: this.model,
format: this.format,
keep_alive: this.keepAlive,
options: {
embedding_only: this.embeddingOnly,
f16_kv: this.f16KV,
frequency_penalty: this.frequencyPenalty,
logits_all: this.logitsAll,
low_vram: this.lowVram,
main_gpu: this.mainGpu,
mirostat: this.mirostat,
mirostat_eta: this.mirostatEta,
mirostat_tau: this.mirostatTau,
num_batch: this.numBatch,
num_ctx: this.numCtx,
num_gpu: this.numGpu,
num_gqa: this.numGqa,
num_keep: this.numKeep,
num_predict: this.numPredict,
num_thread: this.numThread,
penalize_newline: this.penalizeNewline,
presence_penalty: this.presencePenalty,
repeat_last_n: this.repeatLastN,
repeat_penalty: this.repeatPenalty,
rope_frequency_base: this.ropeFrequencyBase,
rope_frequency_scale: this.ropeFrequencyScale,
temperature: this.temperature,
stop: options?.stop ?? this.stop,
tfs_z: this.tfsZ,
top_k: this.topK,
top_p: this.topP,
typical_p: this.typicalP,
use_mlock: this.useMLock,
use_mmap: this.useMMap,
vocab_only: this.vocabOnly,
seed: this.seed,
},
};
}
_combineLLMOutput() {
return {};
}
/** @deprecated */
async *_streamResponseChunksLegacy(
input: BaseMessage[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const stream = createOllamaGenerateStream(
this.baseUrl,
{
...this.invocationParams(options),
prompt: this._formatMessagesAsPrompt(input),
},
{
...options,
headers: this.headers,
}
);
for await (const chunk of stream) {
if (!chunk.done) {
yield new ChatGenerationChunk({
text: chunk.response,
message: new AIMessageChunk({ content: chunk.response }),
});
await runManager?.handleLLMNewToken(chunk.response ?? "");
} else {
yield new ChatGenerationChunk({
text: "",
message: new AIMessageChunk({ content: "" }),
generationInfo: {
model: chunk.model,
total_duration: chunk.total_duration,
load_duration: chunk.load_duration,
prompt_eval_count: chunk.prompt_eval_count,
prompt_eval_duration: chunk.prompt_eval_duration,
eval_count: chunk.eval_count,
eval_duration: chunk.eval_duration,
},
});
}
}
}
async *_streamResponseChunks(
input: BaseMessage[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
try {
const stream = await this.caller.call(async () =>
createOllamaChatStream(
this.baseUrl,
{
...this.invocationParams(options),
messages: this._convertMessagesToOllamaMessages(input),
},
{
...options,
headers: this.headers,
}
)
);
for await (const chunk of stream) {
if (!chunk.done) {
yield new ChatGenerationChunk({
text: chunk.message.content,
message: new AIMessageChunk({ content: chunk.message.content }),
});
await runManager?.handleLLMNewToken(chunk.message.content ?? "");
} else {
yield new ChatGenerationChunk({
text: "",
message: new AIMessageChunk({ content: "" }),
generationInfo: {
model: chunk.model,
total_duration: chunk.total_duration,
load_duration: chunk.load_duration,
prompt_eval_count: chunk.prompt_eval_count,
prompt_eval_duration: chunk.prompt_eval_duration,
eval_count: chunk.eval_count,
eval_duration: chunk.eval_duration,
},
});
}
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
if (e.response?.status === 404) {
console.warn(
"[WARNING]: It seems you are using a legacy version of Ollama. Please upgrade to a newer version for better chat support."
);
yield* this._streamResponseChunksLegacy(input, options, runManager);
} else {
throw e;
}
}
}
protected _convertMessagesToOllamaMessages(
messages: BaseMessage[]
): OllamaMessage[] {
return messages.map((message) => {
let role;
if (message._getType() === "human") {
role = "user";
} else if (message._getType() === "ai") {
role = "assistant";
} else if (message._getType() === "system") {
role = "system";
} else {
throw new Error(
`Unsupported message type for Ollama: ${message._getType()}`
);
}
let content = "";
const images = [];
if (typeof message.content === "string") {
content = message.content;
} else {
for (const contentPart of message.content) {
if (contentPart.type === "text") {
content = `${content}\n${contentPart.text}`;
} else if (
contentPart.type === "image_url" &&
typeof contentPart.image_url === "string"
) {
const imageUrlComponents = contentPart.image_url.split(",");
// Support both data:image/jpeg;base64,<image> format as well
images.push(imageUrlComponents[1] ?? imageUrlComponents[0]);
} else {
throw new Error(
`Unsupported message content type. Must either have type "text" or type "image_url" with a string "image_url" field.`
);
}
}
}
return {
role,
content,
images,
};
});
}
/** @deprecated */
protected _formatMessagesAsPrompt(messages: BaseMessage[]): string {
const formattedMessages = messages
.map((message) => {
let messageText;
if (message._getType() === "human") {
messageText = `[INST] ${message.content} [/INST]`;
} else if (message._getType() === "ai") {
messageText = message.content;
} else if (message._getType() === "system") {
messageText = `<<SYS>> ${message.content} <</SYS>>`;
} else if (ChatMessage.isInstance(message)) {
messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice(
1
)}: ${message.content}`;
} else {
console.warn(
`Unsupported message type passed to Ollama: "${message._getType()}"`
);
messageText = "";
}
return messageText;
})
.join("\n");
return formattedMessages;
}
/** @ignore */
async _call(
messages: BaseMessage[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): Promise<string> {
const chunks = [];
for await (const chunk of this._streamResponseChunks(
messages,
options,
runManager
)) {
chunks.push(chunk.message.content);
}
return chunks.join("");
}
}

201
src/models/utils/ollama.ts Normal file
View File

@@ -0,0 +1,201 @@
import { IterableReadableStream } from "@langchain/core/utils/stream";
import type { StringWithAutocomplete } from "@langchain/core/utils/types";
import { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
export interface OllamaInput {
embeddingOnly?: boolean;
f16KV?: boolean;
frequencyPenalty?: number;
headers?: Record<string, string>;
keepAlive?: string;
logitsAll?: boolean;
lowVram?: boolean;
mainGpu?: number;
model?: string;
baseUrl?: string;
mirostat?: number;
mirostatEta?: number;
mirostatTau?: number;
numBatch?: number;
numCtx?: number;
numGpu?: number;
numGqa?: number;
numKeep?: number;
numPredict?: number;
numThread?: number;
penalizeNewline?: boolean;
presencePenalty?: number;
repeatLastN?: number;
repeatPenalty?: number;
ropeFrequencyBase?: number;
ropeFrequencyScale?: number;
temperature?: number;
stop?: string[];
tfsZ?: number;
topK?: number;
topP?: number;
typicalP?: number;
useMLock?: boolean;
useMMap?: boolean;
vocabOnly?: boolean;
seed?: number;
format?: StringWithAutocomplete<"json">;
}
export interface OllamaRequestParams {
model: string;
format?: StringWithAutocomplete<"json">;
images?: string[];
options: {
embedding_only?: boolean;
f16_kv?: boolean;
frequency_penalty?: number;
logits_all?: boolean;
low_vram?: boolean;
main_gpu?: number;
mirostat?: number;
mirostat_eta?: number;
mirostat_tau?: number;
num_batch?: number;
num_ctx?: number;
num_gpu?: number;
num_gqa?: number;
num_keep?: number;
num_thread?: number;
num_predict?: number;
penalize_newline?: boolean;
presence_penalty?: number;
repeat_last_n?: number;
repeat_penalty?: number;
rope_frequency_base?: number;
rope_frequency_scale?: number;
temperature?: number;
stop?: string[];
tfs_z?: number;
top_k?: number;
top_p?: number;
typical_p?: number;
use_mlock?: boolean;
use_mmap?: boolean;
vocab_only?: boolean;
};
}
export type OllamaMessage = {
role: StringWithAutocomplete<"user" | "assistant" | "system">;
content: string;
images?: string[];
};
export interface OllamaGenerateRequestParams extends OllamaRequestParams {
prompt: string;
}
export interface OllamaChatRequestParams extends OllamaRequestParams {
messages: OllamaMessage[];
}
export type BaseOllamaGenerationChunk = {
model: string;
created_at: string;
done: boolean;
total_duration?: number;
load_duration?: number;
prompt_eval_count?: number;
prompt_eval_duration?: number;
eval_count?: number;
eval_duration?: number;
};
export type OllamaGenerationChunk = BaseOllamaGenerationChunk & {
response: string;
};
export type OllamaChatGenerationChunk = BaseOllamaGenerationChunk & {
message: OllamaMessage;
};
export type OllamaCallOptions = BaseLanguageModelCallOptions & {
headers?: Record<string, string>;
};
async function* createOllamaStream(
url: string,
params: OllamaRequestParams,
options: OllamaCallOptions
) {
let formattedUrl = url;
if (formattedUrl.startsWith("http://localhost:")) {
// Node 18 has issues with resolving "localhost"
// See https://github.com/node-fetch/node-fetch/issues/1624
formattedUrl = formattedUrl.replace(
"http://localhost:",
"http://127.0.0.1:"
);
}
const response = await fetch(formattedUrl, {
method: "POST",
body: JSON.stringify(params),
headers: {
"Content-Type": "application/json",
...options.headers,
},
signal: options.signal,
});
if (!response.ok) {
let error;
const responseText = await response.text();
try {
const json = JSON.parse(responseText);
error = new Error(
`Ollama call failed with status code ${response.status}: ${json.error}`
);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
error = new Error(
`Ollama call failed with status code ${response.status}: ${responseText}`
);
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(error as any).response = response;
throw error;
}
if (!response.body) {
throw new Error(
"Could not begin Ollama stream. Please check the given URL and try again."
);
}
const stream = IterableReadableStream.fromReadableStream(response.body);
const decoder = new TextDecoder();
let extra = "";
for await (const chunk of stream) {
const decoded = extra + decoder.decode(chunk);
const lines = decoded.split("\n");
extra = lines.pop() || "";
for (const line of lines) {
try {
yield JSON.parse(line);
} catch (e) {
console.warn(`Received a non-JSON parseable chunk: ${line}`);
}
}
}
}
export async function* createOllamaGenerateStream(
baseUrl: string,
params: OllamaGenerateRequestParams,
options: OllamaCallOptions
): AsyncGenerator<OllamaGenerationChunk> {
yield* createOllamaStream(`${baseUrl}/api/generate`, params, options);
}
export async function* createOllamaChatStream(
baseUrl: string,
params: OllamaChatRequestParams,
options: OllamaCallOptions
): AsyncGenerator<OllamaChatGenerationChunk> {
yield* createOllamaStream(`${baseUrl}/api/chat`, params, options);
}