feat: Add custom headers support

n4ze3m
2024-06-30 00:21:43 +05:30
parent 86296c96b6
commit 52f9a2953a
14 changed files with 2126 additions and 1902 deletions


@@ -1,6 +1,7 @@
import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"
import type { StringWithAutocomplete } from "@langchain/core/utils/types"
import { parseKeepAlive } from "./utils/ollama"
+import { getCustomOllamaHeaders } from "@/services/app"
export interface OllamaInput {
embeddingOnly?: boolean
@@ -213,12 +214,14 @@ export class OllamaEmbeddingsPageAssist extends Embeddings {
"http://127.0.0.1:"
)
}
+const customHeaders = await getCustomOllamaHeaders()
const response = await fetch(`${formattedBaseUrl}/api/embeddings`, {
method: "POST",
headers: {
"Content-Type": "application/json",
-...this.headers
+...this.headers,
+...customHeaders
},
body: JSON.stringify({
prompt,

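Note on the helper used in both files: the embeddings request above and the streaming requests below spread the result of getCustomOllamaHeaders() (imported from "@/services/app") into their fetch headers, but the helper itself is not part of this commit. A minimal illustrative sketch of what it might look like, assuming the extension persists user-defined header pairs with @plasmohq/storage under a hypothetical "customOllamaHeaders" key (the library, key name, and stored shape are assumptions, not taken from this diff):

// Hypothetical sketch only; the real implementation in @/services/app may differ.
import { Storage } from "@plasmohq/storage"

// Assumed stored shape: header pairs saved from a settings UI.
type CustomHeader = { key: string; value: string }

const storage = new Storage()

export const getCustomOllamaHeaders = async (): Promise<
  Record<string, string>
> => {
  // Assumed storage key; nothing configured means no extra headers.
  const saved = await storage.get<CustomHeader[]>("customOllamaHeaders")
  if (!saved || saved.length === 0) {
    return {}
  }
  // Flatten the pairs into the plain record that gets spread into fetch headers.
  return saved.reduce<Record<string, string>>((acc, { key, value }) => {
    if (key.trim().length > 0) {
      acc[key.trim()] = value
    }
    return acc
  }, {})
}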

@@ -1,184 +1,189 @@
-import { IterableReadableStream } from "@langchain/core/utils/stream";
-import type { StringWithAutocomplete } from "@langchain/core/utils/types";
-import { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
+import { IterableReadableStream } from "@langchain/core/utils/stream"
+import type { StringWithAutocomplete } from "@langchain/core/utils/types"
+import { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"
+import { getCustomOllamaHeaders } from "@/services/app"
export interface OllamaInput {
-embeddingOnly?: boolean;
-f16KV?: boolean;
-frequencyPenalty?: number;
-headers?: Record<string, string>;
-keepAlive?: any;
-logitsAll?: boolean;
-lowVram?: boolean;
-mainGpu?: number;
-model?: string;
-baseUrl?: string;
-mirostat?: number;
-mirostatEta?: number;
-mirostatTau?: number;
-numBatch?: number;
-numCtx?: number;
-numGpu?: number;
-numGqa?: number;
-numKeep?: number;
-numPredict?: number;
-numThread?: number;
-penalizeNewline?: boolean;
-presencePenalty?: number;
-repeatLastN?: number;
-repeatPenalty?: number;
-ropeFrequencyBase?: number;
-ropeFrequencyScale?: number;
-temperature?: number;
-stop?: string[];
-tfsZ?: number;
-topK?: number;
-topP?: number;
-typicalP?: number;
-useMLock?: boolean;
-useMMap?: boolean;
-vocabOnly?: boolean;
-seed?: number;
-format?: StringWithAutocomplete<"json">;
+embeddingOnly?: boolean
+f16KV?: boolean
+frequencyPenalty?: number
+headers?: Record<string, string>
+keepAlive?: any
+logitsAll?: boolean
+lowVram?: boolean
+mainGpu?: number
+model?: string
+baseUrl?: string
+mirostat?: number
+mirostatEta?: number
+mirostatTau?: number
+numBatch?: number
+numCtx?: number
+numGpu?: number
+numGqa?: number
+numKeep?: number
+numPredict?: number
+numThread?: number
+penalizeNewline?: boolean
+presencePenalty?: number
+repeatLastN?: number
+repeatPenalty?: number
+ropeFrequencyBase?: number
+ropeFrequencyScale?: number
+temperature?: number
+stop?: string[]
+tfsZ?: number
+topK?: number
+topP?: number
+typicalP?: number
+useMLock?: boolean
+useMMap?: boolean
+vocabOnly?: boolean
+seed?: number
+format?: StringWithAutocomplete<"json">
}
export interface OllamaRequestParams {
-model: string;
-format?: StringWithAutocomplete<"json">;
-images?: string[];
+model: string
+format?: StringWithAutocomplete<"json">
+images?: string[]
options: {
-embedding_only?: boolean;
-f16_kv?: boolean;
-frequency_penalty?: number;
-logits_all?: boolean;
-low_vram?: boolean;
-main_gpu?: number;
-mirostat?: number;
-mirostat_eta?: number;
-mirostat_tau?: number;
-num_batch?: number;
-num_ctx?: number;
-num_gpu?: number;
-num_gqa?: number;
-num_keep?: number;
-num_thread?: number;
-num_predict?: number;
-penalize_newline?: boolean;
-presence_penalty?: number;
-repeat_last_n?: number;
-repeat_penalty?: number;
-rope_frequency_base?: number;
-rope_frequency_scale?: number;
-temperature?: number;
-stop?: string[];
-tfs_z?: number;
-top_k?: number;
-top_p?: number;
-typical_p?: number;
-use_mlock?: boolean;
-use_mmap?: boolean;
-vocab_only?: boolean;
-};
+embedding_only?: boolean
+f16_kv?: boolean
+frequency_penalty?: number
+logits_all?: boolean
+low_vram?: boolean
+main_gpu?: number
+mirostat?: number
+mirostat_eta?: number
+mirostat_tau?: number
+num_batch?: number
+num_ctx?: number
+num_gpu?: number
+num_gqa?: number
+num_keep?: number
+num_thread?: number
+num_predict?: number
+penalize_newline?: boolean
+presence_penalty?: number
+repeat_last_n?: number
+repeat_penalty?: number
+rope_frequency_base?: number
+rope_frequency_scale?: number
+temperature?: number
+stop?: string[]
+tfs_z?: number
+top_k?: number
+top_p?: number
+typical_p?: number
+use_mlock?: boolean
+use_mmap?: boolean
+vocab_only?: boolean
+}
}
export type OllamaMessage = {
-role: StringWithAutocomplete<"user" | "assistant" | "system">;
-content: string;
-images?: string[];
-};
+role: StringWithAutocomplete<"user" | "assistant" | "system">
+content: string
+images?: string[]
+}
export interface OllamaGenerateRequestParams extends OllamaRequestParams {
-prompt: string;
+prompt: string
}
export interface OllamaChatRequestParams extends OllamaRequestParams {
-messages: OllamaMessage[];
+messages: OllamaMessage[]
}
export type BaseOllamaGenerationChunk = {
-model: string;
-created_at: string;
-done: boolean;
-total_duration?: number;
-load_duration?: number;
-prompt_eval_count?: number;
-prompt_eval_duration?: number;
-eval_count?: number;
-eval_duration?: number;
-};
+model: string
+created_at: string
+done: boolean
+total_duration?: number
+load_duration?: number
+prompt_eval_count?: number
+prompt_eval_duration?: number
+eval_count?: number
+eval_duration?: number
+}
export type OllamaGenerationChunk = BaseOllamaGenerationChunk & {
-response: string;
-};
+response: string
+}
export type OllamaChatGenerationChunk = BaseOllamaGenerationChunk & {
-message: OllamaMessage;
-};
+message: OllamaMessage
+}
export type OllamaCallOptions = BaseLanguageModelCallOptions & {
-headers?: Record<string, string>;
-};
+headers?: Record<string, string>
+}
async function* createOllamaStream(
url: string,
params: OllamaRequestParams,
options: OllamaCallOptions
) {
-let formattedUrl = url;
+let formattedUrl = url
if (formattedUrl.startsWith("http://localhost:")) {
// Node 18 has issues with resolving "localhost"
// See https://github.com/node-fetch/node-fetch/issues/1624
formattedUrl = formattedUrl.replace(
"http://localhost:",
"http://127.0.0.1:"
-);
+)
}
+const customHeaders = await getCustomOllamaHeaders()
const response = await fetch(formattedUrl, {
method: "POST",
body: JSON.stringify(params),
headers: {
"Content-Type": "application/json",
...options.headers,
+...customHeaders
},
-signal: options.signal,
-});
+signal: options.signal
+})
if (!response.ok) {
-let error;
-const responseText = await response.text();
+let error
+const responseText = await response.text()
try {
-const json = JSON.parse(responseText);
+const json = JSON.parse(responseText)
error = new Error(
`Ollama call failed with status code ${response.status}: ${json.error}`
-);
+)
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
error = new Error(
`Ollama call failed with status code ${response.status}: ${responseText}`
-);
+)
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
-(error as any).response = response;
-throw error;
+;(error as any).response = response
+throw error
}
if (!response.body) {
throw new Error(
"Could not begin Ollama stream. Please check the given URL and try again."
-);
+)
}
-const stream = IterableReadableStream.fromReadableStream(response.body);
+const stream = IterableReadableStream.fromReadableStream(response.body)
-const decoder = new TextDecoder();
-let extra = "";
+const decoder = new TextDecoder()
+let extra = ""
for await (const chunk of stream) {
-const decoded = extra + decoder.decode(chunk);
-const lines = decoded.split("\n");
-extra = lines.pop() || "";
+const decoded = extra + decoder.decode(chunk)
+const lines = decoded.split("\n")
+extra = lines.pop() || ""
for (const line of lines) {
try {
-yield JSON.parse(line);
+yield JSON.parse(line)
} catch (e) {
-console.warn(`Received a non-JSON parseable chunk: ${line}`);
+console.warn(`Received a non-JSON parseable chunk: ${line}`)
}
}
}
@@ -189,7 +194,7 @@ export async function* createOllamaGenerateStream(
params: OllamaGenerateRequestParams,
options: OllamaCallOptions
): AsyncGenerator<OllamaGenerationChunk> {
-yield* createOllamaStream(`${baseUrl}/api/generate`, params, options);
+yield* createOllamaStream(`${baseUrl}/api/generate`, params, options)
}
export async function* createOllamaChatStream(
@@ -197,13 +202,12 @@ export async function* createOllamaChatStream(
params: OllamaChatRequestParams,
options: OllamaCallOptions
): AsyncGenerator<OllamaChatGenerationChunk> {
-yield* createOllamaStream(`${baseUrl}/api/chat`, params, options);
+yield* createOllamaStream(`${baseUrl}/api/chat`, params, options)
}
export const parseKeepAlive = (keepAlive: any) => {
if (keepAlive === "-1") {
return -1
}
return keepAlive
}
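How the two header sources combine, for reference: per-request headers supplied through OllamaCallOptions are spread before the stored custom headers, so a stored header with the same name overrides the per-call value. A minimal caller sketch (illustration only, not part of this commit; the base URL, model name, token, and import path are placeholders):

// Illustrative caller; values below are placeholders.
import { createOllamaChatStream } from "./utils/ollama"

const demo = async () => {
  const stream = createOllamaChatStream(
    "http://localhost:11434",
    {
      model: "llama3",
      options: {},
      messages: [{ role: "user", content: "Hello" }]
    },
    {
      // Per-call headers; any key also returned by getCustomOllamaHeaders()
      // is overridden by the stored value because customHeaders is spread last.
      headers: { Authorization: "Bearer <token>" }
    }
  )

  let output = ""
  for await (const chunk of stream) {
    // The last chunk carries timing stats with done === true; skip it.
    if (!chunk.done) {
      output += chunk.message.content
    }
  }
  console.log(output)
}

demo().catch(console.error)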