Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions packages/common/src/vscode-webui-bridge/webview-stub.ts
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ const VSCodeHostStub = {
readModelList: async () => {
return Promise.resolve({} as ThreadSignalSerialization<DisplayModel[]>);
},
refreshModelList: async () => {
return Promise.resolve();
},
readUserStorage: async () => {
return Promise.resolve(
{} as ThreadSignalSerialization<Record<string, UserInfo>>,
Expand Down
5 changes: 5 additions & 0 deletions packages/common/src/vscode-webui-bridge/webview.ts
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,11 @@ export interface VSCodeHostApi {

readModelList(): Promise<ThreadSignalSerialization<DisplayModel[]>>;

/**
* Manually refresh the model list from vendors and providers
*/
refreshModelList(): Promise<void>;

readUserStorage(): Promise<
ThreadSignalSerialization<Record<string, UserInfo>>
>;
Expand Down
2 changes: 1 addition & 1 deletion packages/vendor-pochi/src/vendor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ export class Pochi extends VendorBase {
}

override async fetchModels(): Promise<Record<string, ModelOptions>> {
if (!this.cachedModels) {
if (!this.cachedModels || Object.keys(this.cachedModels).length === 0) {
const apiClient: PochiApiClient = hc<PochiApi>(getServerBaseUrl());
const data = await withRetry(
async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ export const Default: Story = {
onChange: (v) => console.log("Selected model:", v),
isLoading: false,
isValid: true,
refreshModels: () => console.log("Refreshing models..."),
},
};

Expand All @@ -112,6 +113,7 @@ export const LoadingState: Story = {
onChange: (v) => console.log("Selected model:", v),
isLoading: true,
isValid: false,
refreshModels: () => console.log("Refreshing models..."),
},
};

Expand All @@ -121,7 +123,21 @@ export const NoModels: Story = {
value: undefined,
onChange: (v) => console.log("Selected model:", v),
isLoading: false,
isRefreshing: false,
isValid: false,
refreshModels: () => console.log("Refreshing models..."),
},
};

export const Refreshing: Story = {
args: {
models: [],
value: undefined,
onChange: (v) => console.log("Selected model:", v),
isLoading: false,
isRefreshing: true,
isValid: false,
refreshModels: () => console.log("Refreshing models..."),
},
};

Expand All @@ -131,6 +147,8 @@ export const Invalid: Story = {
value: mockModels[0].models[0],
onChange: (v) => console.log("Selected model:", v),
isLoading: false,
isRefreshing: false,
isValid: false,
refreshModels: () => console.log("Refreshing models..."),
},
};
38 changes: 37 additions & 1 deletion packages/vscode-webui/src/components/model-select.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,17 @@ import {
} from "@/components/ui/hover-card";
import { Skeleton } from "@/components/ui/skeleton";
import { cn } from "@/lib/utils";
import { CheckIcon, ChevronDownIcon, TriangleAlertIcon } from "lucide-react";
import {
CheckIcon,
ChevronDownIcon,
RefreshCwIcon,
TriangleAlertIcon,
} from "lucide-react";
import { useTranslation } from "react-i18next";

import LoadingWrapper from "@/components/loading-wrapper";
import type { ModelGroups } from "@/features/settings";
import { useUserStorage } from "@/lib/hooks/use-user-storage";
import type { DisplayModel } from "@getpochi/common/vscode-webui-bridge";
import { DropdownMenuPortal } from "@radix-ui/react-dropdown-menu";

Expand All @@ -32,23 +38,37 @@ interface ModelSelectProps {
value: ModelSelectValue | undefined;
onChange: (v: string) => void;
isLoading?: boolean;
isRefreshing?: boolean;
isValid?: boolean;
triggerClassName?: string;
refreshModels: () => void;
}

export function ModelSelect({
models,
value,
onChange,
isLoading,
isRefreshing,
isValid,
triggerClassName,
refreshModels,
}: ModelSelectProps) {
const { t } = useTranslation();
const {
users: { pochi: user } = {},
} = useUserStorage();

const hostedModels = models?.filter((x) => !x.isCustom);
const customModels = models?.filter((x) => x.isCustom);

const shouldShowReloadButton =
!hostedModels ||
((hostedModels.find((g) => g.title === "Super")?.models.length ?? 0) ===
0 &&
(hostedModels.find((g) => g.title === "Swift")?.models.length ?? 0) ===
0);

const onSelectModel = (v: DisplayModel) => {
onChange(v.id);
};
Expand Down Expand Up @@ -113,6 +133,22 @@ export function ModelSelect({
alignOffset={6}
className="dropdown-menu max-h-[32vh] min-w-[18rem] animate-in overflow-y-auto overflow-x-hidden rounded-md border bg-background p-2 text-popover-foreground shadow"
>
{!!user && !isLoading && shouldShowReloadButton && (
<div className="flex justify-center py-4">
<Button
onClick={() => refreshModels()}
variant="outline"
size="sm"
className="gap-2"
disabled={isRefreshing}
>
<RefreshCwIcon
className={cn("size-4", isRefreshing && "animate-spin")}
/>
{t("modelSelect.reload")}
</Button>
</div>
)}
<DropdownMenuRadioGroup>
{hostedModels
?.filter((group) => group.models.length > 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ export const ChatToolbar: React.FC<ChatToolbarProps> = ({
selectedModel,
selectedModelFromStore, // for fallback display
isLoading: isModelsLoading,
isRefreshing: isRefreshingModels,
updateSelectedModelId,
refreshModels,
} = useSelectedModels({ isSubTask });

// Use the unified attachment upload hook
Expand Down Expand Up @@ -332,8 +334,10 @@ export const ChatToolbar: React.FC<ChatToolbarProps> = ({
value={selectedModel || selectedModelFromStore}
models={groupedModels}
isLoading={isModelsLoading}
isRefreshing={isRefreshingModels}
isValid={!!selectedModel}
onChange={updateSelectedModelId}
refreshModels={refreshModels}
/>
</div>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ export const CreateTaskInput: React.FC<CreateTaskInputProps> = ({
selectedModel,
selectedModelFromStore, // for fallback display
isLoading: isModelsLoading,
isRefreshing: isRefreshingModels,
updateSelectedModelId,
refreshModels,
} = useSelectedModels({ isSubTask: false });

// Use the unified attachment upload hook
Expand Down Expand Up @@ -312,8 +314,10 @@ export const CreateTaskInput: React.FC<CreateTaskInputProps> = ({
value={selectedModel || selectedModelFromStore}
models={groupedModels}
isLoading={isModelsLoading}
isRefreshing={isRefreshingModels}
isValid={!!selectedModel}
onChange={updateSelectedModelId}
refreshModels={refreshModels}
/>
</div>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ export function useSelectedModels(options?: UseSelectedModelsOptions) {
const { t } = useTranslation();
const isSubTask = options?.isSubTask ?? false;

const { modelList: models, isLoading } = useModelList(true);
const {
modelList: models,
isLoading,
isRefreshing,
refresh,
} = useModelList(true);
const { selectedModel: selectedModelFromStore } = useSettingsStore();
const { updateSelectedModel, selectedModel: storedSelectedModel } =
useModelSelectionState(isSubTask);
Expand Down Expand Up @@ -97,12 +102,14 @@ export function useSelectedModels(options?: UseSelectedModelsOptions) {

return {
isLoading,
isRefreshing,
models,
groupedModels,
// model with full information
selectedModel,
updateSelectedModelId,
// model for fallback display
selectedModelFromStore,
refreshModels: refresh,
};
}
3 changes: 2 additions & 1 deletion packages/vscode-webui/src/i18n/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@
"modelUnavailable": "Model is unavailable or still loading",
"custom": "Custom",
"super": "Super",
"swift": "Swift"
"swift": "Swift",
"reload": "Reload"
},
"worktreeSelect": {
"selectWorktree": "Select Worktree",
Expand Down
3 changes: 2 additions & 1 deletion packages/vscode-webui/src/i18n/locales/jp.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@
"modelUnavailable": "モデルが利用できないか、読み込み中です。",
"custom": "カスタム",
"super": "加強",
"swift": "極速"
"swift": "極速",
"reload": "再読み込み"
},
"worktreeSelect": {
"selectWorktree": "ワークツリーを選択",
Expand Down
3 changes: 2 additions & 1 deletion packages/vscode-webui/src/i18n/locales/ko.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@
"modelUnavailable": "모델을 사용할 수 없거나 로드 중입니다.",
"custom": "사용자 정의",
"super": "가강",
"swift": "극속"
"swift": "극속",
"reload": "새로고침"
},
"worktreeSelect": {
"selectWorktree": "워크트리 선택",
Expand Down
3 changes: 2 additions & 1 deletion packages/vscode-webui/src/i18n/locales/zh.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@
"modelUnavailable": "模型不可用或仍在加载中。",
"custom": "自定义",
"super": "加强",
"swift": "极速"
"swift": "极速",
"reload": "重新加载"
},
"worktreeSelect": {
"selectWorktree": "选择工作树",
Expand Down
16 changes: 14 additions & 2 deletions packages/vscode-webui/src/lib/hooks/use-model-list.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { vscodeHost } from "@/lib/vscode";
import type { DisplayModel } from "@getpochi/common/vscode-webui-bridge";
import { threadSignal } from "@quilted/threads/signals";
import { useQuery } from "@tanstack/react-query";
import { useMemo } from "react";
import { useCallback, useMemo, useState } from "react";

/** @useSignals this comment is needed to enable signals in this hook */
export const useModelList = (filterPochiModels: boolean) => {
Expand All @@ -13,6 +13,8 @@ export const useModelList = (filterPochiModels: boolean) => {
staleTime: Number.POSITIVE_INFINITY,
});

const [isRefreshing, setIsRefreshing] = useState(false);

const enablePochiModels = useEnablePochiModels();

const modelList = useMemo(() => {
Expand All @@ -26,7 +28,17 @@ export const useModelList = (filterPochiModels: boolean) => {
: modelListSignal?.value;
}, [filterPochiModels, modelListSignal?.value, enablePochiModels]);

return { modelList, isLoading };
const refresh = useCallback(async () => {
setIsRefreshing(true);
try {
await vscodeHost.refreshModelList();
// The signal will automatically update, so we don't need to call refetch
} finally {
setIsRefreshing(false);
}
}, []);

return { modelList, isLoading, isRefreshing, refresh };
};

async function fetchModelList() {
Expand Down
1 change: 1 addition & 0 deletions packages/vscode-webui/src/lib/vscode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ function createVSCodeHost(): VSCodeHostApi {
"showInformationMessage",
"readVisibleTerminals",
"readModelList",
"refreshModelList",
"readUserStorage",
"readCustomAgents",
"openTaskInPanel",
Expand Down
4 changes: 4 additions & 0 deletions packages/vscode/src/integrations/webview/vscode-host-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,10 @@ export class VSCodeHostImpl implements VSCodeHostApi, vscode.Disposable {
return ThreadSignal.serialize(this.modelList.modelList);
};

refreshModelList = async (): Promise<void> => {
await this.modelList.refresh();
};

readUserStorage = async (): Promise<
ThreadSignalSerialization<Record<string, UserInfo>>
> => {
Expand Down
5 changes: 5 additions & 0 deletions packages/vscode/src/lib/model-list.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ export class ModelList implements vscode.Disposable {
});
}

async refresh(): Promise<void> {
this.modelList.value = await this.fetchModelList();
}

private async fetchModelList(): Promise<DisplayModel[]> {
const modelList: DisplayModel[] = [];

Expand All @@ -34,6 +38,7 @@ export class ModelList implements vscode.Disposable {
for (const [vendorId, vendor] of Object.entries(vendors)) {
if (vendor.authenticated) {
try {
logger.trace("fetch models", vendorId);
const models = await vendor.fetchModels();
for (const [modelId, options] of Object.entries(models)) {
modelList.push({
Expand Down