Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion packages/common/src/vscode-webui-bridge/webview-stub.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,13 @@ const VSCodeHostStub = {
});
},
readModelList: async () => {
return Promise.resolve({} as ThreadSignalSerialization<DisplayModel[]>);
return Promise.resolve(
{} as {
modelList: ThreadSignalSerialization<DisplayModel[]>;
isLoading: ThreadSignalSerialization<boolean>;
reload: () => Promise<void>;
},
);
},
readUserStorage: async () => {
return Promise.resolve(
Expand Down
6 changes: 5 additions & 1 deletion packages/common/src/vscode-webui-bridge/webview.ts
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,11 @@ export interface VSCodeHostApi {
...items: T[]
): Promise<T | undefined>;

readModelList(): Promise<ThreadSignalSerialization<DisplayModel[]>>;
readModelList(): Promise<{
modelList: ThreadSignalSerialization<DisplayModel[]>;
isLoading: ThreadSignalSerialization<boolean>;
reload: () => Promise<void>;
}>;

readUserStorage(): Promise<
ThreadSignalSerialization<Record<string, UserInfo>>
Expand Down
2 changes: 1 addition & 1 deletion packages/vendor-pochi/src/vendor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ export class Pochi extends VendorBase {
}

override async fetchModels(): Promise<Record<string, ModelOptions>> {
if (!this.cachedModels) {
if (!this.cachedModels || Object.keys(this.cachedModels).length === 0) {
const apiClient: PochiApiClient = hc<PochiApi>(getServerBaseUrl());
const data = await withRetry(
async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ export const Default: Story = {
onChange: (v) => console.log("Selected model:", v),
isLoading: false,
isValid: true,
reloadModels: async () => console.log("Refreshing models..."),
},
};

Expand All @@ -112,6 +113,7 @@ export const LoadingState: Story = {
onChange: (v) => console.log("Selected model:", v),
isLoading: true,
isValid: false,
reloadModels: async () => console.log("Refreshing models..."),
},
};

Expand All @@ -121,7 +123,21 @@ export const NoModels: Story = {
value: undefined,
onChange: (v) => console.log("Selected model:", v),
isLoading: false,
isFetching: false,
isValid: false,
reloadModels: async () => console.log("Refreshing models..."),
},
};

export const Refreshing: Story = {
args: {
models: [],
value: undefined,
onChange: (v) => console.log("Selected model:", v),
isLoading: false,
isFetching: true,
isValid: false,
reloadModels: async () => console.log("Refreshing models..."),
},
};

Expand All @@ -131,6 +147,8 @@ export const Invalid: Story = {
value: mockModels[0].models[0],
onChange: (v) => console.log("Selected model:", v),
isLoading: false,
isFetching: false,
isValid: false,
reloadModels: async () => console.log("Refreshing models..."),
},
};
36 changes: 30 additions & 6 deletions packages/vscode-webui/src/components/model-select.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import { Button } from "@/components/ui/button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuRadioGroup,
DropdownMenuRadioItem,
DropdownMenuSeparator,
Expand All @@ -17,7 +16,12 @@ import {
} from "@/components/ui/hover-card";
import { Skeleton } from "@/components/ui/skeleton";
import { cn } from "@/lib/utils";
import { CheckIcon, ChevronDownIcon, TriangleAlertIcon } from "lucide-react";
import {
CheckIcon,
ChevronDownIcon,
RefreshCwIcon,
TriangleAlertIcon,
} from "lucide-react";
import { useTranslation } from "react-i18next";

import LoadingWrapper from "@/components/loading-wrapper";
Expand All @@ -32,17 +36,21 @@ interface ModelSelectProps {
value: ModelSelectValue | undefined;
onChange: (v: string) => void;
isLoading?: boolean;
isFetching?: boolean;
isValid?: boolean;
triggerClassName?: string;
reloadModels?: () => Promise<void>;
}

export function ModelSelect({
models,
value,
onChange,
isLoading,
isFetching,
isValid,
triggerClassName,
reloadModels,
}: ModelSelectProps) {
const { t } = useTranslation();

Expand Down Expand Up @@ -187,18 +195,34 @@ export function ModelSelect({
))}

{!!customModels?.flat().length && <DropdownMenuSeparator />}
<DropdownMenuItem asChild>
<div className="flex items-center justify-between gap-2 px-2">
<a
href="command:pochi.openCustomModelSettings"
target="_blank"
rel="noopener noreferrer"
className="flex cursor-pointer items-center gap-2 px-3 py-1"
className="group cursor-pointer px-3 py-2.5"
>
<span className="text-[var(--vscode-textLink-foreground)] text-xs">
<span className="text-[var(--vscode-textLink-foreground)] text-xs group-hover:underline ">
{t("modelSelect.manageCustomModels")}
</span>
</a>
</DropdownMenuItem>
<span
onClick={() => {
if (isFetching) {
return;
}
reloadModels?.();
}}
className="flex cursor-pointer items-center gap-1 px-3 py-2.5 text-[var(--vscode-textLink-foreground)] text-xs hover:underline"
>
<RefreshCwIcon
className={cn("size-3 opacity-0", {
"animate-spin opacity-100": isFetching,
})}
/>
{t("modelSelect.reload")}
</span>
</div>
</DropdownMenuRadioGroup>
</DropdownMenuContent>
</DropdownMenuPortal>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ export const ChatToolbar: React.FC<ChatToolbarProps> = ({
selectedModel,
selectedModelFromStore, // for fallback display
isLoading: isModelsLoading,
isFetching: isFetchingModels,
reload: reloadModels,
updateSelectedModelId,
} = useSelectedModels({ isSubTask });

Expand Down Expand Up @@ -332,8 +334,10 @@ export const ChatToolbar: React.FC<ChatToolbarProps> = ({
value={selectedModel || selectedModelFromStore}
models={groupedModels}
isLoading={isModelsLoading}
isFetching={isFetchingModels}
isValid={!!selectedModel}
onChange={updateSelectedModelId}
reloadModels={reloadModels}
/>
</div>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ export const CreateTaskInput: React.FC<CreateTaskInputProps> = ({
selectedModel,
selectedModelFromStore, // for fallback display
isLoading: isModelsLoading,
isFetching: isFetchingModels,
reload: reloadModels,
updateSelectedModelId,
} = useSelectedModels({ isSubTask: false });

Expand Down Expand Up @@ -312,8 +314,10 @@ export const CreateTaskInput: React.FC<CreateTaskInputProps> = ({
value={selectedModel || selectedModelFromStore}
models={groupedModels}
isLoading={isModelsLoading}
isFetching={isFetchingModels}
isValid={!!selectedModel}
onChange={updateSelectedModelId}
reloadModels={reloadModels}
/>
</div>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ export function useSelectedModels(options?: UseSelectedModelsOptions) {
const { t } = useTranslation();
const isSubTask = options?.isSubTask ?? false;

const { modelList: models, isLoading } = useModelList(true);
const {
modelList: models,
isLoading,
isFetching,
reload,
} = useModelList(true);
const { selectedModel: selectedModelFromStore } = useSettingsStore();
const { updateSelectedModel, selectedModel: storedSelectedModel } =
useModelSelectionState(isSubTask);
Expand Down Expand Up @@ -97,6 +102,8 @@ export function useSelectedModels(options?: UseSelectedModelsOptions) {

return {
isLoading,
isFetching,
reload,
models,
groupedModels,
// model with full information
Expand Down
3 changes: 2 additions & 1 deletion packages/vscode-webui/src/i18n/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@
"modelUnavailable": "Model is unavailable or still loading",
"custom": "Custom",
"super": "Super",
"swift": "Swift"
"swift": "Swift",
"reload": "Reload"
},
"worktreeSelect": {
"selectWorktree": "Select Worktree",
Expand Down
3 changes: 2 additions & 1 deletion packages/vscode-webui/src/i18n/locales/jp.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@
"modelUnavailable": "モデルが利用できないか、読み込み中です。",
"custom": "カスタム",
"super": "加強",
"swift": "極速"
"swift": "極速",
"reload": "再読み込み"
},
"worktreeSelect": {
"selectWorktree": "ワークツリーを選択",
Expand Down
3 changes: 2 additions & 1 deletion packages/vscode-webui/src/i18n/locales/ko.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@
"modelUnavailable": "모델을 사용할 수 없거나 로드 중입니다.",
"custom": "사용자 정의",
"super": "가강",
"swift": "극속"
"swift": "극속",
"reload": "새로고침"
},
"worktreeSelect": {
"selectWorktree": "워크트리 선택",
Expand Down
3 changes: 2 additions & 1 deletion packages/vscode-webui/src/i18n/locales/zh.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@
"modelUnavailable": "模型不可用或仍在加载中。",
"custom": "自定义",
"super": "加强",
"swift": "极速"
"swift": "极速",
"reload": "重新加载"
},
"worktreeSelect": {
"selectWorktree": "选择工作树",
Expand Down
24 changes: 16 additions & 8 deletions packages/vscode-webui/src/lib/hooks/use-model-list.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import { useEnablePochiModels } from "@/features/settings";
import { vscodeHost } from "@/lib/vscode";
import type { DisplayModel } from "@getpochi/common/vscode-webui-bridge";
import { threadSignal } from "@quilted/threads/signals";
import { useQuery } from "@tanstack/react-query";
import { useMemo } from "react";

/** @useSignals this comment is needed to enable signals in this hook */
export const useModelList = (filterPochiModels: boolean) => {
const { data: modelListSignal, isLoading } = useQuery({
const { data, isLoading } = useQuery({
queryKey: ["modelList"],
queryFn: fetchModelList,
staleTime: Number.POSITIVE_INFINITY,
Expand All @@ -17,19 +16,28 @@ export const useModelList = (filterPochiModels: boolean) => {

const modelList = useMemo(() => {
return filterPochiModels
? modelListSignal?.value?.filter((model) => {
? data?.modelList?.value?.filter((model) => {
if (model.type === "vendor" && model.vendorId === "pochi") {
return !model.modelId.startsWith("pochi/") || enablePochiModels;
}
return true;
})
: modelListSignal?.value;
}, [filterPochiModels, modelListSignal?.value, enablePochiModels]);
: data?.modelList?.value;
}, [filterPochiModels, data?.modelList?.value, enablePochiModels]);

return { modelList, isLoading };
return {
modelList,
isLoading,
isFetching: !!data?.isLoading.value,
reload: data?.reload,
};
};

async function fetchModelList() {
const signal = threadSignal<DisplayModel[]>(await vscodeHost.readModelList());
return signal;
const result = await vscodeHost.readModelList();
return {
modelList: threadSignal(result.modelList),
isLoading: threadSignal(result.isLoading),
reload: result.reload,
};
}
14 changes: 10 additions & 4 deletions packages/vscode/src/integrations/webview/vscode-host-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -826,10 +826,16 @@ export class VSCodeHostImpl implements VSCodeHostApi, vscode.Disposable {
);
};

readModelList = async (): Promise<
ThreadSignalSerialization<DisplayModel[]>
> => {
return ThreadSignal.serialize(this.modelList.modelList);
readModelList = async (): Promise<{
modelList: ThreadSignalSerialization<DisplayModel[]>;
isLoading: ThreadSignalSerialization<boolean>;
reload: () => Promise<void>;
}> => {
return {
modelList: ThreadSignal.serialize(this.modelList.modelList),
isLoading: ThreadSignal.serialize(this.modelList.isLoading),
reload: this.modelList.reload,
};
};

readUserStorage = async (): Promise<
Expand Down
10 changes: 10 additions & 0 deletions packages/vscode/src/lib/model-list.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const logger = getLogger("ModelList");
export class ModelList implements vscode.Disposable {
dispose: () => void;
readonly modelList: Signal<DisplayModel[]> = signal([]);
readonly isLoading: Signal<boolean> = signal(false);

constructor() {
this.dispose = watchPochiConfigKeys(["providers", "vendors"], () => {
Expand All @@ -26,14 +27,21 @@ export class ModelList implements vscode.Disposable {
});
}

reload = async (): Promise<void> => {
this.modelList.value = await this.fetchModelList();
};

private async fetchModelList(): Promise<DisplayModel[]> {
this.isLoading.value = true;

const modelList: DisplayModel[] = [];

const vendors = getVendors();
// From vendors
for (const [vendorId, vendor] of Object.entries(vendors)) {
if (vendor.authenticated) {
try {
logger.trace("fetch models", vendorId);
const models = await vendor.fetchModels();
for (const [modelId, options] of Object.entries(models)) {
modelList.push({
Expand Down Expand Up @@ -82,6 +90,8 @@ export class ModelList implements vscode.Disposable {
}
}

this.isLoading.value = false;

return modelList;
}
}