feat(vision): add vision chat mode

- Add new "vision" chat mode to the application
- Implement the `visionChatMode` function to handle vision-based chat interactions
- Update the UI to include a new button to toggle the vision chat mode
- Add new translations for the "vision" chat mode tooltip
- Hide the internet-search toggle and disable the image-upload button while the vision chat mode is active
Author: n4ze3m
Date: 2024-11-23 14:04:57 +05:30
Parent: edc5380a76
Commit: 2c12b17dda

5 changed files with 368 additions and 47 deletions
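At a high level, the commit wires a new UI toggle to a screenshot-based chat path: when chatMode is "vision", the message hook captures the visible tab, attaches it to the user message as an image part, and streams the reply. A condensed sketch of that flow, with names taken from the diff below (streaming UI updates, history persistence, and error handling omitted):

  // Sketch only: distilled from the diff below; bookkeeping is omitted.
  if (chatMode === "vision") {
    const shot = await getScreenshotFromCurrentTab()      // PNG data URL of the visible tab
    if (!shot?.screenshot) throw new Error("screenshot unavailable")

    const humanMessage = humanMessageFormatter({
      content: [
        { type: "text", text: message },
        { type: "image_url", image_url: shot.screenshot } // multimodal message part
      ],
      model: selectedModel
    })

    const chunks = await ollama.stream([...applicationChatHistory, humanMessage], { signal })
    for await (const chunk of chunks) {
      fullText += chunk?.content                          // accumulate the streamed reply
    }
  }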


@@ -23,7 +23,8 @@
     "speechToText": "Speech to Text",
     "uploadImage": "Upload Image",
     "stopStreaming": "Stop Streaming",
-    "knowledge": "Knowledge"
+    "knowledge": "Knowledge",
+    "vision": "[Experimental] Vision Chat"
   },
   "sendWhenEnter": "Send when Enter pressed",
   "welcome": "Hello! How can I help you today?"


@@ -7,7 +7,14 @@ import { toBase64 } from "~/libs/to-base64"
 import { Checkbox, Dropdown, Image, Switch, Tooltip } from "antd"
 import { useWebUI } from "~/store/webui"
 import { defaultEmbeddingModelForRag } from "~/services/ollama"
-import { ImageIcon, MicIcon, StopCircleIcon, X } from "lucide-react"
+import {
+  ImageIcon,
+  MicIcon,
+  StopCircleIcon,
+  X,
+  EyeIcon,
+  EyeOffIcon
+} from "lucide-react"
 import { useTranslation } from "react-i18next"
 import { ModelSelect } from "@/components/Common/ModelSelect"
 import { useSpeechRecognition } from "@/hooks/useSpeechRecognition"
@@ -36,7 +43,7 @@ export const SidepanelForm = ({ dropedFile }: Props) => {
     resetTranscript,
     start: startListening,
     stop: stopSpeechRecognition,
-    supported: browserSupportsSpeechRecognition,
+    supported: browserSupportsSpeechRecognition
   } = useSpeechRecognition()

   const stopListening = async () => {
@@ -237,7 +244,10 @@ export const SidepanelForm = ({ dropedFile }: Props) => {
       }
     }
     await stopListening()
-    if (value.message.trim().length === 0 && value.image.length === 0) {
+    if (
+      value.message.trim().length === 0 &&
+      value.image.length === 0
+    ) {
       return
     }
     form.reset()
@@ -281,6 +291,7 @@ export const SidepanelForm = ({ dropedFile }: Props) => {
             {...form.getInputProps("message")}
           />
           <div className="flex mt-4 justify-end gap-3">
+            {chatMode !== "vision" && (
             <Tooltip title={t("tooltip.searchInternet")}>
               <button
                 type="button"
@@ -295,6 +306,7 @@ export const SidepanelForm = ({ dropedFile }: Props) => {
                 )}
               </button>
             </Tooltip>
+            )}
             <ModelSelect />
             {browserSupportsSpeechRecognition && (
               <Tooltip title={t("tooltip.speechToText")}>
@@ -323,13 +335,35 @@ export const SidepanelForm = ({ dropedFile }: Props) => {
                 </button>
               </Tooltip>
             )}
+            <Tooltip title={t("tooltip.vision")}>
+              <button
+                type="button"
+                onClick={() => {
+                  if (chatMode === "vision") {
+                    setChatMode("normal")
+                  } else {
+                    setChatMode("vision")
+                  }
+                }}
+                disabled={chatMode === "rag"}
+                className={`flex items-center justify-center dark:text-gray-300 ${
+                  chatMode === "rag" ? "hidden" : "block"
+                } disabled:opacity-50`}>
+                {chatMode === "vision" ? (
+                  <EyeIcon className="h-5 w-5" />
+                ) : (
+                  <EyeOffIcon className="h-5 w-5" />
+                )}
+              </button>
+            </Tooltip>
             <Tooltip title={t("tooltip.uploadImage")}>
               <button
                 type="button"
                 onClick={() => {
                   inputRef.current?.click()
                 }}
-                className={`flex items-center justify-center dark:text-gray-300 ${
+                disabled={chatMode === "vision"}
+                className={`flex items-center justify-center disabled:opacity-50 dark:text-gray-300 ${
                   chatMode === "rag" ? "hidden" : "block"
                 }`}>
                 <ImageIcon className="h-5 w-5" />


@@ -36,6 +36,7 @@ import { humanMessageFormatter } from "@/utils/human-message"
 import { pageAssistEmbeddingModel } from "@/models/embedding"
 import { PageAssistVectorStore } from "@/libs/PageAssistVectorStore"
 import { PAMemoryVectorStore } from "@/libs/PAMemoryVectorStore"
+import { getScreenshotFromCurrentTab } from "@/libs/get-screenshot"

 export const useMessage = () => {
   const {
@@ -136,8 +137,9 @@
       seed: currentChatModelSettings?.seed,
       numGpu:
         currentChatModelSettings?.numGpu ?? userDefaultModelSettings?.numGpu,
-      numPredict: currentChatModelSettings?.numPredict ?? userDefaultModelSettings?.numPredict,
+      numPredict:
+        currentChatModelSettings?.numPredict ??
+        userDefaultModelSettings?.numPredict
     })

     let newMessage: Message[] = []
@@ -265,9 +267,11 @@
             userDefaultModelSettings?.numCtx,
         seed: currentChatModelSettings?.seed,
         numGpu:
-          currentChatModelSettings?.numGpu ?? userDefaultModelSettings?.numGpu,
-        numPredict: currentChatModelSettings?.numPredict ?? userDefaultModelSettings?.numPredict,
+          currentChatModelSettings?.numGpu ??
+          userDefaultModelSettings?.numGpu,
+        numPredict:
+          currentChatModelSettings?.numPredict ??
+          userDefaultModelSettings?.numPredict
       })

       const response = await questionOllama.invoke(promptForQuestion)
       query = response.content.toString()
@@ -342,9 +346,7 @@
           signal: signal,
           callbacks: [
             {
-              handleLLMEnd(
-                output: any,
-              ): any {
+              handleLLMEnd(output: any): any {
                 try {
                   generationInfo = output?.generations?.[0][0]?.generationInfo
                 } catch (e) {
@@ -450,6 +452,236 @@
     }
   }

+  const visionChatMode = async (
+    message: string,
+    image: string,
+    isRegenerate: boolean,
+    messages: Message[],
+    history: ChatHistory,
+    signal: AbortSignal
+  ) => {
+    setStreaming(true)
+    const url = await getOllamaURL()
+    const userDefaultModelSettings = await getAllDefaultModelSettings()
+
+    const ollama = await pageAssistModel({
+      model: selectedModel!,
+      baseUrl: cleanUrl(url),
+      keepAlive:
+        currentChatModelSettings?.keepAlive ??
+        userDefaultModelSettings?.keepAlive,
+      temperature:
+        currentChatModelSettings?.temperature ??
+        userDefaultModelSettings?.temperature,
+      topK: currentChatModelSettings?.topK ?? userDefaultModelSettings?.topK,
+      topP: currentChatModelSettings?.topP ?? userDefaultModelSettings?.topP,
+      numCtx:
+        currentChatModelSettings?.numCtx ?? userDefaultModelSettings?.numCtx,
+      seed: currentChatModelSettings?.seed,
+      numGpu:
+        currentChatModelSettings?.numGpu ?? userDefaultModelSettings?.numGpu,
+      numPredict:
+        currentChatModelSettings?.numPredict ??
+        userDefaultModelSettings?.numPredict
+    })
+
+    let newMessage: Message[] = []
+    let generateMessageId = generateID()
+
+    if (!isRegenerate) {
+      newMessage = [
+        ...messages,
+        {
+          isBot: false,
+          name: "You",
+          message,
+          sources: [],
+          images: []
+        },
+        {
+          isBot: true,
+          name: selectedModel,
+          message: "▋",
+          sources: [],
+          id: generateMessageId
+        }
+      ]
+    } else {
+      newMessage = [
+        ...messages,
+        {
+          isBot: true,
+          name: selectedModel,
+          message: "▋",
+          sources: [],
+          id: generateMessageId
+        }
+      ]
+    }
+    setMessages(newMessage)
+    let fullText = ""
+    let contentToSave = ""
+
+    try {
+      const prompt = await systemPromptForNonRag()
+      const selectedPrompt = await getPromptById(selectedSystemPrompt)
+
+      const applicationChatHistory = generateHistory(history, selectedModel)
+
+      const data = await getScreenshotFromCurrentTab()
+      console.log(
+        data?.success
+          ? `[PageAssist] Screenshot is taken`
+          : `[PageAssist] Screenshot is not taken`
+      )
+      const visionImage = data?.screenshot || ""
+
+      if (visionImage === "") {
+        throw new Error(
+          "Please close and reopen the side panel. This is a bug that will be fixed soon."
+        )
+      }
+
+      if (prompt && !selectedPrompt) {
+        applicationChatHistory.unshift(
+          new SystemMessage({
+            content: prompt
+          })
+        )
+      }
+
+      if (selectedPrompt) {
+        applicationChatHistory.unshift(
+          new SystemMessage({
+            content: selectedPrompt.content
+          })
+        )
+      }
+
+      let humanMessage = humanMessageFormatter({
+        content: [
+          {
+            text: message,
+            type: "text"
+          },
+          {
+            image_url: visionImage,
+            type: "image_url"
+          }
+        ],
+        model: selectedModel
+      })
+
+      let generationInfo: any | undefined = undefined
+
+      const chunks = await ollama.stream(
+        [...applicationChatHistory, humanMessage],
+        {
+          signal: signal,
+          callbacks: [
+            {
+              handleLLMEnd(output: any): any {
+                try {
+                  generationInfo = output?.generations?.[0][0]?.generationInfo
+                } catch (e) {
+                  console.log("handleLLMEnd error", e)
+                }
+              }
+            }
+          ]
+        }
+      )
+      let count = 0
+      for await (const chunk of chunks) {
+        contentToSave += chunk?.content
+        fullText += chunk?.content
+        if (count === 0) {
+          setIsProcessing(true)
+        }
+        setMessages((prev) => {
+          return prev.map((message) => {
+            if (message.id === generateMessageId) {
+              return {
+                ...message,
+                message: fullText + "▋"
+              }
+            }
+            return message
+          })
+        })
+        count++
+      }
+
+      setMessages((prev) => {
+        return prev.map((message) => {
+          if (message.id === generateMessageId) {
+            return {
+              ...message,
+              message: fullText,
+              generationInfo
+            }
+          }
+          return message
+        })
+      })
+
+      setHistory([
+        ...history,
+        {
+          role: "user",
+          content: message
+        },
+        {
+          role: "assistant",
+          content: fullText
+        }
+      ])
+
+      await saveMessageOnSuccess({
+        historyId,
+        setHistoryId,
+        isRegenerate,
+        selectedModel: selectedModel,
+        message,
+        image,
+        fullText,
+        source: [],
+        message_source: "copilot",
+        generationInfo
+      })
+
+      setIsProcessing(false)
+      setStreaming(false)
+    } catch (e) {
+      const errorSave = await saveMessageOnError({
+        e,
+        botMessage: fullText,
+        history,
+        historyId,
+        image,
+        selectedModel,
+        setHistory,
+        setHistoryId,
+        userMessage: message,
+        isRegenerating: isRegenerate,
+        message_source: "copilot"
+      })
+
+      if (!errorSave) {
+        notification.error({
+          message: t("error"),
+          description: e?.message || t("somethingWentWrong")
+        })
+      }
+      setIsProcessing(false)
+      setStreaming(false)
+      setIsProcessing(false)
+      setStreaming(false)
+      setIsEmbedding(false)
+    } finally {
+      setAbortController(null)
+      setEmbeddingController(null)
+    }
+  }
+
   const normalChatMode = async (
     message: string,
     image: string,
@@ -482,8 +714,9 @@
       seed: currentChatModelSettings?.seed,
       numGpu:
         currentChatModelSettings?.numGpu ?? userDefaultModelSettings?.numGpu,
-      numPredict: currentChatModelSettings?.numPredict ?? userDefaultModelSettings?.numPredict,
+      numPredict:
+        currentChatModelSettings?.numPredict ??
+        userDefaultModelSettings?.numPredict
     })

     let newMessage: Message[] = []
@@ -577,9 +810,7 @@
           signal: signal,
           callbacks: [
             {
-              handleLLMEnd(
-                output: any,
-              ): any {
+              handleLLMEnd(output: any): any {
                 try {
                   generationInfo = output?.generations?.[0][0]?.generationInfo
                 } catch (e) {
@@ -711,8 +942,9 @@
       seed: currentChatModelSettings?.seed,
       numGpu:
         currentChatModelSettings?.numGpu ?? userDefaultModelSettings?.numGpu,
-      numPredict: currentChatModelSettings?.numPredict ?? userDefaultModelSettings?.numPredict,
+      numPredict:
+        currentChatModelSettings?.numPredict ??
+        userDefaultModelSettings?.numPredict
     })

     let newMessage: Message[] = []
@@ -787,9 +1019,11 @@
             userDefaultModelSettings?.numCtx,
         seed: currentChatModelSettings?.seed,
         numGpu:
-          currentChatModelSettings?.numGpu ?? userDefaultModelSettings?.numGpu,
-        numPredict: currentChatModelSettings?.numPredict ?? userDefaultModelSettings?.numPredict,
+          currentChatModelSettings?.numGpu ??
+          userDefaultModelSettings?.numGpu,
+        numPredict:
+          currentChatModelSettings?.numPredict ??
+          userDefaultModelSettings?.numPredict
       })

       const response = await questionOllama.invoke(promptForQuestion)
       query = response.content.toString()
@@ -842,9 +1076,7 @@
           signal: signal,
           callbacks: [
             {
-              handleLLMEnd(
-                output: any,
-              ): any {
+              handleLLMEnd(output: any): any {
                 try {
                   generationInfo = output?.generations?.[0][0]?.generationInfo
                 } catch (e) {
@@ -977,8 +1209,9 @@
       seed: currentChatModelSettings?.seed,
       numGpu:
         currentChatModelSettings?.numGpu ?? userDefaultModelSettings?.numGpu,
-      numPredict: currentChatModelSettings?.numPredict ?? userDefaultModelSettings?.numPredict,
+      numPredict:
+        currentChatModelSettings?.numPredict ??
+        userDefaultModelSettings?.numPredict
     })

     let newMessage: Message[] = []
@@ -1052,9 +1285,7 @@
           signal: signal,
           callbacks: [
             {
-              handleLLMEnd(
-                output: any,
-              ): any {
+              handleLLMEnd(output: any): any {
                 try {
                   generationInfo = output?.generations?.[0][0]?.generationInfo
                 } catch (e) {
@@ -1216,6 +1447,15 @@
           signal
         )
       }
+    } else if (chatMode === "vision") {
+      await visionChatMode(
+        message,
+        image,
+        isRegenerate,
+        chatHistory || messages,
+        memory || history,
+        signal
+      )
     } else {
       const newEmbeddingController = new AbortController()
       let embeddingSignal = newEmbeddingController.signal


@@ -0,0 +1,46 @@
+const captureVisibleTab = () => {
+  const result = new Promise<string>((resolve) => {
+    if (import.meta.env.BROWSER === "chrome") {
+      chrome.tabs.query({ active: true, currentWindow: true }, async (tabs) => {
+        const tab = tabs[0]
+        chrome.tabs.captureVisibleTab(null, { format: "png" }, (dataUrl) => {
+          resolve(dataUrl)
+        })
+      })
+    } else {
+      browser.tabs
+        .query({ active: true, currentWindow: true })
+        .then(async (tabs) => {
+          const dataUrl = (await Promise.race([
+            browser.tabs.captureVisibleTab(null, { format: "png" }),
+            new Promise((_, reject) =>
+              setTimeout(
+                () => reject(new Error("Screenshot capture timed out")),
+                10000
+              )
+            )
+          ])) as string
+          resolve(dataUrl)
+        })
+    }
+  })
+  return result
+}
+
+export const getScreenshotFromCurrentTab = async () => {
+  try {
+    const screenshotDataUrl = await captureVisibleTab()
+    return {
+      success: true,
+      screenshot: screenshotDataUrl,
+      error: null
+    }
+  } catch (error) {
+    return {
+      success: false,
+      screenshot: null,
+      error:
+        error instanceof Error ? error.message : "Failed to capture screenshot"
+    }
+  }
+}
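The helper never rejects: failures are folded into a { success, screenshot, error } result, and the Firefox branch races captureVisibleTab against a 10-second timeout so a hung capture surfaces as an error instead of blocking the send. A minimal usage sketch, mirroring the check done in visionChatMode above (the handleScreenshot caller is hypothetical, not part of this commit):

  // Sketch only: handleScreenshot is a hypothetical consumer.
  const data = await getScreenshotFromCurrentTab()
  if (!data.success || !data.screenshot) {
    console.warn("[PageAssist] screenshot failed:", data.error)
  } else {
    handleScreenshot(data.screenshot) // PNG data URL, usable as an image_url chat part
  }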


@@ -32,8 +32,8 @@ type State = {
   setIsProcessing: (isProcessing: boolean) => void
   selectedModel: string | null
   setSelectedModel: (selectedModel: string) => void
-  chatMode: "normal" | "rag"
-  setChatMode: (chatMode: "normal" | "rag") => void
+  chatMode: "normal" | "rag" | "vision"
+  setChatMode: (chatMode: "normal" | "rag" | "vision") => void
   isEmbedding: boolean
   setIsEmbedding: (isEmbedding: boolean) => void
   speechToTextLanguage: string
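With the union widened, toggling into the new mode type-checks wherever the store value is consumed; a minimal sketch, assuming chatMode and setChatMode are read through the useMessage() hook shown above:

  // Sketch only: assumes chatMode/setChatMode are exposed via useMessage().
  const { chatMode, setChatMode } = useMessage()
  setChatMode(chatMode === "vision" ? "normal" : "vision") // "vision" is now a valid value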