feat: Add custom headers support
@@ -1,6 +1,7 @@
 import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings"
 import type { StringWithAutocomplete } from "@langchain/core/utils/types"
 import { parseKeepAlive } from "./utils/ollama"
+import { getCustomOllamaHeaders } from "@/services/app"
 
 export interface OllamaInput {
   embeddingOnly?: boolean
@@ -213,12 +214,14 @@ export class OllamaEmbeddingsPageAssist extends Embeddings {
         "http://127.0.0.1:"
       )
     }
+    const customHeaders = await getCustomOllamaHeaders()
 
     const response = await fetch(`${formattedBaseUrl}/api/embeddings`, {
       method: "POST",
       headers: {
         "Content-Type": "application/json",
-        ...this.headers
+        ...this.headers,
+        ...customHeaders
       },
       body: JSON.stringify({
         prompt,
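Note on merge order (reviewer sketch, not part of the diff): because ...customHeaders is spread after ...this.headers, a header configured through getCustomOllamaHeaders() overrides a per-instance header with the same key. A minimal TypeScript sketch of that resolution, with invented header names and values:

// Sketch only: demonstrates how object spread resolves duplicate header keys.
// "Authorization" and "X-Trace" are hypothetical examples, not real config.
const instanceHeaders = { Authorization: "Bearer per-instance", "X-Trace": "abc" }
const customHeaders = { Authorization: "Bearer user-configured" }

const merged = {
  "Content-Type": "application/json",
  ...instanceHeaders,
  ...customHeaders
}

// merged.Authorization === "Bearer user-configured"  -> the later spread wins
// merged["X-Trace"] === "abc"                         -> keys not overridden pass through
console.log(merged)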
@@ -1,184 +1,189 @@
-import { IterableReadableStream } from "@langchain/core/utils/stream";
-import type { StringWithAutocomplete } from "@langchain/core/utils/types";
-import { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
+import { IterableReadableStream } from "@langchain/core/utils/stream"
+import type { StringWithAutocomplete } from "@langchain/core/utils/types"
+import { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"
+import { getCustomOllamaHeaders } from "@/services/app"
 
 export interface OllamaInput {
-  embeddingOnly?: boolean;
-  f16KV?: boolean;
-  frequencyPenalty?: number;
-  headers?: Record<string, string>;
-  keepAlive?: any;
-  logitsAll?: boolean;
-  lowVram?: boolean;
-  mainGpu?: number;
-  model?: string;
-  baseUrl?: string;
-  mirostat?: number;
-  mirostatEta?: number;
-  mirostatTau?: number;
-  numBatch?: number;
-  numCtx?: number;
-  numGpu?: number;
-  numGqa?: number;
-  numKeep?: number;
-  numPredict?: number;
-  numThread?: number;
-  penalizeNewline?: boolean;
-  presencePenalty?: number;
-  repeatLastN?: number;
-  repeatPenalty?: number;
-  ropeFrequencyBase?: number;
-  ropeFrequencyScale?: number;
-  temperature?: number;
-  stop?: string[];
-  tfsZ?: number;
-  topK?: number;
-  topP?: number;
-  typicalP?: number;
-  useMLock?: boolean;
-  useMMap?: boolean;
-  vocabOnly?: boolean;
-  seed?: number;
-  format?: StringWithAutocomplete<"json">;
+  embeddingOnly?: boolean
+  f16KV?: boolean
+  frequencyPenalty?: number
+  headers?: Record<string, string>
+  keepAlive?: any
+  logitsAll?: boolean
+  lowVram?: boolean
+  mainGpu?: number
+  model?: string
+  baseUrl?: string
+  mirostat?: number
+  mirostatEta?: number
+  mirostatTau?: number
+  numBatch?: number
+  numCtx?: number
+  numGpu?: number
+  numGqa?: number
+  numKeep?: number
+  numPredict?: number
+  numThread?: number
+  penalizeNewline?: boolean
+  presencePenalty?: number
+  repeatLastN?: number
+  repeatPenalty?: number
+  ropeFrequencyBase?: number
+  ropeFrequencyScale?: number
+  temperature?: number
+  stop?: string[]
+  tfsZ?: number
+  topK?: number
+  topP?: number
+  typicalP?: number
+  useMLock?: boolean
+  useMMap?: boolean
+  vocabOnly?: boolean
+  seed?: number
+  format?: StringWithAutocomplete<"json">
 }
 
 export interface OllamaRequestParams {
-  model: string;
-  format?: StringWithAutocomplete<"json">;
-  images?: string[];
+  model: string
+  format?: StringWithAutocomplete<"json">
+  images?: string[]
   options: {
-    embedding_only?: boolean;
-    f16_kv?: boolean;
-    frequency_penalty?: number;
-    logits_all?: boolean;
-    low_vram?: boolean;
-    main_gpu?: number;
-    mirostat?: number;
-    mirostat_eta?: number;
-    mirostat_tau?: number;
-    num_batch?: number;
-    num_ctx?: number;
-    num_gpu?: number;
-    num_gqa?: number;
-    num_keep?: number;
-    num_thread?: number;
-    num_predict?: number;
-    penalize_newline?: boolean;
-    presence_penalty?: number;
-    repeat_last_n?: number;
-    repeat_penalty?: number;
-    rope_frequency_base?: number;
-    rope_frequency_scale?: number;
-    temperature?: number;
-    stop?: string[];
-    tfs_z?: number;
-    top_k?: number;
-    top_p?: number;
-    typical_p?: number;
-    use_mlock?: boolean;
-    use_mmap?: boolean;
-    vocab_only?: boolean;
-  };
+    embedding_only?: boolean
+    f16_kv?: boolean
+    frequency_penalty?: number
+    logits_all?: boolean
+    low_vram?: boolean
+    main_gpu?: number
+    mirostat?: number
+    mirostat_eta?: number
+    mirostat_tau?: number
+    num_batch?: number
+    num_ctx?: number
+    num_gpu?: number
+    num_gqa?: number
+    num_keep?: number
+    num_thread?: number
+    num_predict?: number
+    penalize_newline?: boolean
+    presence_penalty?: number
+    repeat_last_n?: number
+    repeat_penalty?: number
+    rope_frequency_base?: number
+    rope_frequency_scale?: number
+    temperature?: number
+    stop?: string[]
+    tfs_z?: number
+    top_k?: number
+    top_p?: number
+    typical_p?: number
+    use_mlock?: boolean
+    use_mmap?: boolean
+    vocab_only?: boolean
+  }
 }
 
 export type OllamaMessage = {
-  role: StringWithAutocomplete<"user" | "assistant" | "system">;
-  content: string;
-  images?: string[];
-};
+  role: StringWithAutocomplete<"user" | "assistant" | "system">
+  content: string
+  images?: string[]
+}
 
 export interface OllamaGenerateRequestParams extends OllamaRequestParams {
-  prompt: string;
+  prompt: string
 }
 
 export interface OllamaChatRequestParams extends OllamaRequestParams {
-  messages: OllamaMessage[];
+  messages: OllamaMessage[]
 }
 
 export type BaseOllamaGenerationChunk = {
-  model: string;
-  created_at: string;
-  done: boolean;
-  total_duration?: number;
-  load_duration?: number;
-  prompt_eval_count?: number;
-  prompt_eval_duration?: number;
-  eval_count?: number;
-  eval_duration?: number;
-};
+  model: string
+  created_at: string
+  done: boolean
+  total_duration?: number
+  load_duration?: number
+  prompt_eval_count?: number
+  prompt_eval_duration?: number
+  eval_count?: number
+  eval_duration?: number
+}
 
 export type OllamaGenerationChunk = BaseOllamaGenerationChunk & {
-  response: string;
-};
+  response: string
+}
 
 export type OllamaChatGenerationChunk = BaseOllamaGenerationChunk & {
-  message: OllamaMessage;
-};
+  message: OllamaMessage
+}
 
 export type OllamaCallOptions = BaseLanguageModelCallOptions & {
-  headers?: Record<string, string>;
-};
+  headers?: Record<string, string>
+}
 
 async function* createOllamaStream(
   url: string,
   params: OllamaRequestParams,
   options: OllamaCallOptions
 ) {
-  let formattedUrl = url;
+  let formattedUrl = url
   if (formattedUrl.startsWith("http://localhost:")) {
     // Node 18 has issues with resolving "localhost"
     // See https://github.com/node-fetch/node-fetch/issues/1624
     formattedUrl = formattedUrl.replace(
       "http://localhost:",
       "http://127.0.0.1:"
-    );
+    )
   }
+
+  const customHeaders = await getCustomOllamaHeaders()
+
   const response = await fetch(formattedUrl, {
     method: "POST",
     body: JSON.stringify(params),
     headers: {
       "Content-Type": "application/json",
       ...options.headers,
+      ...customHeaders
     },
-    signal: options.signal,
-  });
+    signal: options.signal
+  })
   if (!response.ok) {
-    let error;
-    const responseText = await response.text();
+    let error
+    const responseText = await response.text()
     try {
-      const json = JSON.parse(responseText);
+      const json = JSON.parse(responseText)
       error = new Error(
         `Ollama call failed with status code ${response.status}: ${json.error}`
-      );
+      )
       // eslint-disable-next-line @typescript-eslint/no-explicit-any
     } catch (e: any) {
       error = new Error(
         `Ollama call failed with status code ${response.status}: ${responseText}`
-      );
+      )
     }
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    (error as any).response = response;
-    throw error;
+    ;(error as any).response = response
+    throw error
   }
   if (!response.body) {
     throw new Error(
       "Could not begin Ollama stream. Please check the given URL and try again."
-    );
+    )
   }
 
-  const stream = IterableReadableStream.fromReadableStream(response.body);
+  const stream = IterableReadableStream.fromReadableStream(response.body)
 
-  const decoder = new TextDecoder();
-  let extra = "";
+  const decoder = new TextDecoder()
+  let extra = ""
   for await (const chunk of stream) {
-    const decoded = extra + decoder.decode(chunk);
-    const lines = decoded.split("\n");
-    extra = lines.pop() || "";
+    const decoded = extra + decoder.decode(chunk)
+    const lines = decoded.split("\n")
+    extra = lines.pop() || ""
     for (const line of lines) {
       try {
-        yield JSON.parse(line);
+        yield JSON.parse(line)
       } catch (e) {
-        console.warn(`Received a non-JSON parseable chunk: ${line}`);
+        console.warn(`Received a non-JSON parseable chunk: ${line}`)
       }
     }
   }
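For context, createOllamaStream parses the newline-delimited JSON that Ollama streams back and yields one object per line; the exported wrappers in the hunks below point it at /api/generate and /api/chat. A rough usage sketch of the generate wrapper, assuming a local Ollama server and an installed model (the base URL, model name, and prompt are placeholders, not values from this commit):

// Sketch only: consumes the exported generator; values are illustrative.
async function demo(): Promise<void> {
  const chunks = createOllamaGenerateStream(
    "http://127.0.0.1:11434",                          // placeholder base URL
    { model: "llama3", prompt: "Hello", options: {} }, // placeholder request params
    {}                                                 // no per-call headers or abort signal
  )
  let text = ""
  for await (const chunk of chunks) {
    text += chunk.response // each chunk carries a partial response string
    if (chunk.done) break  // the final chunk also reports timing/eval counts
  }
  console.log(text)
}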
@@ -189,7 +194,7 @@ export async function* createOllamaGenerateStream(
   params: OllamaGenerateRequestParams,
   options: OllamaCallOptions
 ): AsyncGenerator<OllamaGenerationChunk> {
-  yield* createOllamaStream(`${baseUrl}/api/generate`, params, options);
+  yield* createOllamaStream(`${baseUrl}/api/generate`, params, options)
 }
 
 export async function* createOllamaChatStream(
@@ -197,13 +202,12 @@ export async function* createOllamaChatStream(
   params: OllamaChatRequestParams,
   options: OllamaCallOptions
 ): AsyncGenerator<OllamaChatGenerationChunk> {
-  yield* createOllamaStream(`${baseUrl}/api/chat`, params, options);
+  yield* createOllamaStream(`${baseUrl}/api/chat`, params, options)
 }
 
-
 export const parseKeepAlive = (keepAlive: any) => {
   if (keepAlive === "-1") {
     return -1
   }
   return keepAlive
 }
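getCustomOllamaHeaders() itself (in "@/services/app") is not shown in this commit; the call sites above only rely on it resolving to a plain Record<string, string> that can be spread into fetch headers. A hypothetical stand-in, purely to illustrate that expected shape (the storage key and stored format are invented here, not taken from the real service):

// Hypothetical sketch, NOT the implementation in "@/services/app".
// Assumes user-configured headers are persisted as JSON [{ key, value }, ...]
// under an invented storage key.
export const getCustomOllamaHeaders = async (): Promise<Record<string, string>> => {
  const raw = localStorage.getItem("customOllamaHeaders") // invented key
  if (!raw) return {}
  const headers: Record<string, string> = {}
  try {
    for (const h of JSON.parse(raw) as { key: string; value: string }[]) {
      headers[h.key] = h.value
    }
  } catch {
    // malformed config: fall back to sending no extra headers
  }
  return headers
}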