Skip to content

Flux Kontext UI support #8111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions invokeai/app/invocations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
Imagen3Model = "Imagen3ModelField"
Imagen4Model = "Imagen4ModelField"
ChatGPT4oModel = "ChatGPT4oModelField"
FluxKontextModel = "FluxKontextModelField"
# endregion

# region Misc Field Types
Expand Down
1 change: 1 addition & 0 deletions invokeai/backend/model_manager/taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class BaseModelType(str, Enum):
Imagen3 = "imagen3"
Imagen4 = "imagen4"
ChatGPT4o = "chatgpt-4o"
FluxKontext = "flux-kontext"


class ModelType(str, Enum):
Expand Down
2 changes: 2 additions & 0 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,7 @@
"modelIncompatibleScaledBboxWidth": "Scaled bbox width is {{width}} but {{model}} requires multiple of {{multiple}}",
"modelIncompatibleScaledBboxHeight": "Scaled bbox height is {{height}} but {{model}} requires multiple of {{multiple}}",
"fluxModelMultipleControlLoRAs": "Can only use 1 Control LoRA at a time",
"fluxKontextMultipleReferenceImages": "Can only use 1 Reference Image at a time with Flux Kontext",
"canvasIsFiltering": "Canvas is busy (filtering)",
"canvasIsTransforming": "Canvas is busy (transforming)",
"canvasIsRasterizing": "Canvas is busy (rasterizing)",
Expand Down Expand Up @@ -1337,6 +1338,7 @@
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
"imagenIncompatibleGenerationMode": "Google {{model}} supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
"fluxKontextIncompatibleGenerationMode": "Flux Kontext supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
"workflowUnpublished": "Workflow Unpublished"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatch
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
Expand Down Expand Up @@ -59,6 +60,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
return await buildImagen4Graph(state, manager);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, manager);
case 'flux-kontext':
return await buildFluxKontextGraph(state, manager);
default:
assert(false, `No graph builders for base ${base}`);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import type {
import {
initialChatGPT4oReferenceImage,
initialControlNet,
initialFluxKontextReferenceImage,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
Expand Down Expand Up @@ -87,6 +88,12 @@ export const selectDefaultRefImageConfig = createSelector(
return referenceImage;
}

if (selectedMainModel?.base === 'flux-kontext') {
const referenceImage = deepClone(initialFluxKontextReferenceImage);
referenceImage.model = zModelIdentifierField.parse(selectedMainModel);
return referenceImage;
}

const { data } = query;
let model: IPAdapterModelConfig | null = null;
if (data) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ import { useAppSelector } from 'app/store/storeHooks';
import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsFluxKontext,
selectIsImagen3,
selectIsImagen4,
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import { selectActiveReferenceImageEntities } from 'features/controlLayers/store/selectors';
import type { CanvasEntityType } from 'features/controlLayers/store/types';
import { useMemo } from 'react';
import type { Equals } from 'tsafe';
Expand All @@ -17,23 +19,28 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
const isImagen3 = useAppSelector(selectIsImagen3);
const isImagen4 = useAppSelector(selectIsImagen4);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isFluxKontext = useAppSelector(selectIsFluxKontext);
const activeReferenceImageEntities = useAppSelector(selectActiveReferenceImageEntities);

const isEntityTypeEnabled = useMemo<boolean>(() => {
switch (entityType) {
case 'reference_image':
if (isFluxKontext) {
return activeReferenceImageEntities.length === 0;
}
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4;
case 'regional_guidance':
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
case 'control_layer':
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
case 'inpaint_mask':
return !isImagen3 && !isImagen4 && !isChatGPT4o;
return !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
case 'raster_layer':
return !isImagen3 && !isImagen4 && !isChatGPT4o;
return !isImagen3 && !isImagen4 && !isFluxKontext && !isChatGPT4o;
default:
assert<Equals<typeof entityType, never>>(false);
}
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isChatGPT4o]);
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isFluxKontext, isChatGPT4o, activeReferenceImageEntities]);

return isEntityTypeEnabled;
};
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ import {
initialChatGPT4oReferenceImage,
initialControlLoRA,
initialControlNet,
initialFluxKontextReferenceImage,
initialFLUXRedux,
initialIPAdapter,
initialT2IAdapter,
Expand Down Expand Up @@ -686,6 +687,16 @@ export const canvasSlice = createSlice({
return;
}

if (entity.ipAdapter.model.base === 'flux-kontext') {
// Switching to flux-kontext
entity.ipAdapter = {
...initialFluxKontextReferenceImage,
image: entity.ipAdapter.image,
model: entity.ipAdapter.model,
};
return;
}

if (entity.ipAdapter.model.type === 'flux_redux') {
// Switching to flux_redux
entity.ipAdapter = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ export const selectIsCogView4 = createParamsSelector((params) => params.model?.b
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
export const selectIsImagen4 = createParamsSelector((params) => params.model?.base === 'imagen4');
export const selectIsChatGTP4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');
export const selectIsFluxKontext = createParamsSelector((params) => params.model?.base === 'flux-kontext');

export const selectModel = createParamsSelector((params) => params.model);
export const selectModelKey = createParamsSelector((params) => params.model?.key);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,13 @@ const zChatGPT4oReferenceImageConfig = z.object({
});
export type ChatGPT4oReferenceImageConfig = z.infer<typeof zChatGPT4oReferenceImageConfig>;

const zFluxKontextReferenceImageConfig = z.object({
type: z.literal('flux_kontext_reference_image'),
image: zImageWithDims.nullable(),
model: zServerValidatedModelIdentifierField.nullable(),
});
export type FluxKontextReferenceImageConfig = z.infer<typeof zFluxKontextReferenceImageConfig>;

const zCanvasEntityBase = z.object({
id: zId,
name: zName,
Expand All @@ -268,7 +275,12 @@ const zCanvasEntityBase = z.object({
const zCanvasReferenceImageState = zCanvasEntityBase.extend({
type: z.literal('reference_image'),
// This should be named `referenceImage` but we need to keep it as `ipAdapter` for backwards compatibility
ipAdapter: z.discriminatedUnion('type', [zIPAdapterConfig, zFLUXReduxConfig, zChatGPT4oReferenceImageConfig]),
ipAdapter: z.discriminatedUnion('type', [
zIPAdapterConfig,
zFLUXReduxConfig,
zChatGPT4oReferenceImageConfig,
zFluxKontextReferenceImageConfig,
]),
});
export type CanvasReferenceImageState = z.infer<typeof zCanvasReferenceImageState>;

Expand All @@ -280,6 +292,9 @@ export const isFLUXReduxConfig = (config: CanvasReferenceImageState['ipAdapter']
export const isChatGPT4oReferenceImageConfig = (
config: CanvasReferenceImageState['ipAdapter']
): config is ChatGPT4oReferenceImageConfig => config.type === 'chatgpt_4o_reference_image';
export const isFluxKontextReferenceImageConfig = (
config: CanvasReferenceImageState['ipAdapter']
): config is FluxKontextReferenceImageConfig => config.type === 'flux_kontext_reference_image';

const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']);
export type FillStyle = z.infer<typeof zFillStyle>;
Expand Down Expand Up @@ -416,6 +431,10 @@ export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);
export const isChatGPT4oAspectRatioID = (v: unknown): v is z.infer<typeof zChatGPT4oAspectRatioID> =>
zChatGPT4oAspectRatioID.safeParse(v).success;

export const zFluxKontextAspectRatioID = z.enum(['16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export const isFluxKontextAspectRatioID = (v: unknown): v is z.infer<typeof zFluxKontextAspectRatioID> =>
zFluxKontextAspectRatioID.safeParse(v).success;

export type AspectRatioID = z.infer<typeof zAspectRatioID>;
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import type {
ChatGPT4oReferenceImageConfig,
ControlLoRAConfig,
ControlNetConfig,
FluxKontextReferenceImageConfig,
FLUXReduxConfig,
ImageWithDims,
IPAdapterConfig,
Expand Down Expand Up @@ -83,6 +84,11 @@ export const initialChatGPT4oReferenceImage: ChatGPT4oReferenceImageConfig = {
image: null,
model: null,
};
export const initialFluxKontextReferenceImage: FluxKontextReferenceImageConfig = {
type: 'flux_kontext_reference_image',
image: null,
model: null,
};
export const initialT2IAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
model: null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
imagen3: 'pink',
imagen4: 'pink',
'chatgpt-4o': 'pink',
'flux-kontext': 'pink',
};

const ModelBaseBadge = ({ base }: Props) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { FloatFieldSlider } from 'features/nodes/components/flow/nodes/Invocatio
import ChatGPT4oModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ChatGPT4oModelFieldInputComponent';
import { FloatFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatFieldCollectionInputComponent';
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
import FluxKontextModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxKontextModelFieldInputComponent';
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageGeneratorFieldComponent';
import Imagen3ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen3ModelFieldInputComponent';
Expand Down Expand Up @@ -50,6 +51,8 @@ import {
isFloatFieldInputTemplate,
isFloatGeneratorFieldInputInstance,
isFloatGeneratorFieldInputTemplate,
isFluxKontextModelFieldInputInstance,
isFluxKontextModelFieldInputTemplate,
isFluxMainModelFieldInputInstance,
isFluxMainModelFieldInputTemplate,
isFluxReduxModelFieldInputInstance,
Expand Down Expand Up @@ -417,6 +420,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <Imagen4ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}

if (isFluxKontextModelFieldInputTemplate(template)) {
if (!isFluxKontextModelFieldInputInstance(field)) {
return null;
}
return <FluxKontextModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}

if (isChatGPT4oModelFieldInputTemplate(template)) {
if (!isChatGPT4oModelFieldInputInstance(field)) {
return null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldFluxKontextModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
FluxKontextModelFieldInputInstance,
FluxKontextModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useFluxKontextModels } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';

import type { FieldComponentProps } from './types';

const FluxKontextModelFieldInputComponent = (
props: FieldComponentProps<FluxKontextModelFieldInputInstance, FluxKontextModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();

const [modelConfigs, { isLoading }] = useFluxKontextModels();

const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldFluxKontextModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);

return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};

export default memo(FluxKontextModelFieldInputComponent);
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ const NODE_TYPE_PUBLISH_DENYLIST = [
'google_imagen4_generate',
'chatgpt_create_image',
'chatgpt_edit_image',
'flux_kontext_generate_image',
];

export const selectHasUnpublishableNodes = createSelector(selectNodes, (nodes) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import type {
FieldValue,
FloatFieldValue,
FloatGeneratorFieldValue,
FluxKontextModelFieldValue,
FluxReduxModelFieldValue,
FluxVAEModelFieldValue,
ImageFieldCollectionValue,
Expand Down Expand Up @@ -75,6 +76,7 @@ import {
zFloatFieldCollectionValue,
zFloatFieldValue,
zFloatGeneratorFieldValue,
zFluxKontextModelFieldValue,
zFluxReduxModelFieldValue,
zFluxVAEModelFieldValue,
zImageFieldCollectionValue,
Expand Down Expand Up @@ -527,6 +529,9 @@ export const nodesSlice = createSlice({
fieldChatGPT4oModelValueChanged: (state, action: FieldValueAction<ChatGPT4oModelFieldValue>) => {
fieldValueReducer(state, action, zChatGPT4oModelFieldValue);
},
fieldFluxKontextModelValueChanged: (state, action: FieldValueAction<FluxKontextModelFieldValue>) => {
fieldValueReducer(state, action, zFluxKontextModelFieldValue);
},
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue);
},
Expand Down Expand Up @@ -697,6 +702,7 @@ export const {
fieldImagen3ModelValueChanged,
fieldImagen4ModelValueChanged,
fieldChatGPT4oModelValueChanged,
fieldFluxKontextModelValueChanged,
fieldFloatGeneratorValueChanged,
fieldIntegerGeneratorValueChanged,
fieldStringGeneratorValueChanged,
Expand Down
2 changes: 2 additions & 0 deletions invokeai/frontend/web/src/features/nodes/types/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ const zBaseModel = z.enum([
'imagen3',
'imagen4',
'chatgpt-4o',
'flux-kontext',
]);
export type BaseModelType = z.infer<typeof zBaseModel>;
export const zMainModelBase = z.enum([
Expand All @@ -90,6 +91,7 @@ export const zMainModelBase = z.enum([
'imagen3',
'imagen4',
'chatgpt-4o',
'flux-kontext',
]);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
Expand Down
Loading