Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final String moderationModelName;
private final String imageModelName;
private final Optional<String> beanName;
private final DeclarativeAiServiceInputGuardrails inputGuardrails;
private final DeclarativeAiServiceOutputGuardrails outputGuardrails;
private final Integer maxSequentialToolInvocations;

public DeclarativeAiServiceBuildItem(
Expand All @@ -51,6 +53,8 @@ public DeclarativeAiServiceBuildItem(
DotName toolProviderClassDotName,
Optional<String> beanName,
DotName toolHallucinationStrategyClassDotName,
DeclarativeAiServiceInputGuardrails inputGuardrails,
DeclarativeAiServiceOutputGuardrails outputGuardrails,
Integer maxSequentialToolInvocations) {
this.serviceClassInfo = serviceClassInfo;
this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName;
Expand All @@ -69,6 +73,8 @@ public DeclarativeAiServiceBuildItem(
this.toolProviderClassDotName = toolProviderClassDotName;
this.beanName = beanName;
this.toolHallucinationStrategyClassDotName = toolHallucinationStrategyClassDotName;
this.inputGuardrails = inputGuardrails;
this.outputGuardrails = outputGuardrails;
this.maxSequentialToolInvocations = maxSequentialToolInvocations;
}

Expand Down Expand Up @@ -140,6 +146,35 @@ public DotName getToolHallucinationStrategyClassDotName() {
return toolHallucinationStrategyClassDotName;
}

public DeclarativeAiServiceInputGuardrails getInputGuardrails() {
return inputGuardrails;
}

public DeclarativeAiServiceOutputGuardrails getOutputGuardrails() {
return outputGuardrails;
}

public record DeclarativeAiServiceInputGuardrails(List<ClassInfo> inputGuardrailClassInfos) {
public List<String> asClassNames() {
return this.inputGuardrailClassInfos.stream()
.map(classInfo -> classInfo.name().toString())
.toList();
}
}

public record DeclarativeAiServiceOutputGuardrails(List<ClassInfo> outputGuardrailClassInfos, int maxRetries,
int actualMaxRetries) {
public DeclarativeAiServiceOutputGuardrails(List<ClassInfo> outputGuardrailClassInfos, int maxRetries) {
this(outputGuardrailClassInfos, maxRetries, maxRetries);
}

public List<String> asClassNames() {
return this.outputGuardrailClassInfos.stream()
.map(classInfo -> classInfo.name().toString())
.toList();
}
}

public Integer getMaxSequentialToolInvocations() {
return maxSequentialToolInvocations;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,64 @@
import org.jboss.logging.Logger;

import dev.langchain4j.data.message.UserMessage;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import dev.langchain4j.guardrail.InputGuardrail;
import dev.langchain4j.guardrail.InputGuardrailRequest;
import dev.langchain4j.guardrail.InputGuardrailResult;
import dev.langchain4j.guardrail.OutputGuardrail;
import dev.langchain4j.guardrail.OutputGuardrailRequest;
import dev.langchain4j.guardrail.OutputGuardrailResult;

final class GuardrailObservabilityProcessorSupport {
private static final Logger LOG = Logger.getLogger(GuardrailObservabilityProcessorSupport.class);
private static final DotName INPUT_GUARDRAIL_PARAMS = DotName.createSimple(InputGuardrailParams.class);

/**
* @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
*/
@Deprecated(forRemoval = true)
private static final DotName QUARKUS_INPUT_GUARDRAIL_PARAMS = DotName
.createSimple(io.quarkiverse.langchain4j.guardrails.InputGuardrailParams.class);

/**
* @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
*/
@Deprecated(forRemoval = true)
private static final DotName QUARKUS_INPUT_GUARDRAIL_RESULT = DotName
.createSimple(io.quarkiverse.langchain4j.guardrails.InputGuardrailResult.class);

/**
* @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
*/
@Deprecated(forRemoval = true)
private static final DotName QUARKUS_OUTPUT_GUARDRAIL_PARAMS = DotName
.createSimple(io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams.class);

/**
* @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
*/
@Deprecated(forRemoval = true)
private static final DotName QUARKUS_OUTPUT_GUARDRAIL_RESULT = DotName
.createSimple(io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult.class);

/**
* @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
*/
@Deprecated(forRemoval = true)
private static final DotName QUARKUS_INPUT_GUARDRAIL = DotName
.createSimple(io.quarkiverse.langchain4j.guardrails.InputGuardrail.class);

/**
* @deprecated These tests will go away once the Quarkus-specific guardrail implementation has been fully removed
*/
@Deprecated(forRemoval = true)
private static final DotName QUARKUS_OUTPUT_GUARDRAIL = DotName
.createSimple(io.quarkiverse.langchain4j.guardrails.OutputGuardrail.class);
private static final DotName INPUT_GUARDRAIL_REQUEST = DotName.createSimple(InputGuardrailRequest.class);
private static final DotName INPUT_GUARDRAIL_RESULT = DotName.createSimple(InputGuardrailResult.class);
private static final DotName OUTPUT_GUARDRAIL_PARAMS = DotName.createSimple(OutputGuardrailParams.class);
private static final DotName OUTPUT_GUARDRAIL_REQUEST = DotName.createSimple(OutputGuardrailRequest.class);
private static final DotName OUTPUT_GUARDRAIL_RESULT = DotName.createSimple(OutputGuardrailResult.class);
private static final DotName INPUT_GUARDRAIL = DotName.createSimple(InputGuardrail.class);
private static final DotName OUTPUT_GUARDRAIL = DotName.createSimple(OutputGuardrail.class);

static final DotName MICROMETER_TIMED = DotName.createSimple("io.micrometer.core.annotation.Timed");
static final DotName MICROMETER_COUNTED = DotName.createSimple("io.micrometer.core.annotation.Counted");
static final DotName WITH_SPAN = DotName.createSimple("io.opentelemetry.instrumentation.annotations.WithSpan");
Expand All @@ -37,11 +80,17 @@ enum TransformType {
}

enum GuardrailType {
QUARKUS_INPUT,
QUARKUS_OUTPUT,
INPUT,
OUTPUT;

static Optional<GuardrailType> from(IndexView indexView, ClassInfo classToCheck) {
if (indexView.getAllKnownImplementors(INPUT_GUARDRAIL).contains(classToCheck)) {
if (indexView.getAllKnownImplementors(QUARKUS_INPUT_GUARDRAIL).contains(classToCheck)) {
return Optional.of(QUARKUS_INPUT);
} else if (indexView.getAllKnownImplementors(QUARKUS_OUTPUT_GUARDRAIL).contains(classToCheck)) {
return Optional.of(QUARKUS_OUTPUT);
} else if (indexView.getAllKnownImplementors(INPUT_GUARDRAIL).contains(classToCheck)) {
return Optional.of(INPUT);
} else if (indexView.getAllKnownImplementors(OUTPUT_GUARDRAIL).contains(classToCheck)) {
return Optional.of(OUTPUT);
Expand Down Expand Up @@ -102,16 +151,18 @@ private static boolean shouldTransformGuardrailValidateMethod(MethodInfo methodI
}

var isOtherValidateMethodVariant = switch (guardrailType) {
case INPUT -> isInputGuardrailValidateMethodWithUserMessage(methodInfo);
case OUTPUT -> isOutputGuardrailValidateMethodWithAiMessage(methodInfo);
case QUARKUS_INPUT, INPUT -> isInputGuardrailValidateMethodWithUserMessage(methodInfo);
case QUARKUS_OUTPUT, OUTPUT -> isOutputGuardrailValidateMethodWithAiMessage(methodInfo);
};

if (isOtherValidateMethodVariant && !doesMethodAlreadyHaveTransformationAnnotation(methodInfo, transformType)) {
// If this is the other method variant, we need to ensure that the
// variant with the params isn't also present on the method's declaring class
var paramType = switch (guardrailType) {
case INPUT -> Type.parse(INPUT_GUARDRAIL_PARAMS.toString());
case OUTPUT -> Type.parse(OUTPUT_GUARDRAIL_PARAMS.toString());
case QUARKUS_INPUT -> Type.parse(QUARKUS_INPUT_GUARDRAIL_PARAMS.toString());
case QUARKUS_OUTPUT -> Type.parse(QUARKUS_OUTPUT_GUARDRAIL_PARAMS.toString());
case INPUT -> Type.parse(INPUT_GUARDRAIL_REQUEST.toString());
case OUTPUT -> Type.parse(OUTPUT_GUARDRAIL_REQUEST.toString());
};

var otherValidateMethod = methodDeclaringClass.method("validate", paramType);
Expand All @@ -129,9 +180,16 @@ private static boolean shouldTransformGuardrailValidateMethod(MethodInfo methodI
* Checks the method meets <strong>ALL</strong> the following conditions:
* <ul>
* <li>The method's name is {@link #VALIDATE_METHOD_NAME}</li>
* <li><strong>IF</strong> the method's single parameter's type is {@link InputGuardrailParams} then the return type must be
* <li><strong>IF</strong> the method's single parameter's type is
* {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailParams} then the return type must be
* {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailResult}</li>
* <li><strong>IF</strong> the method's single parameter's type is
* {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams} then the return type must
* be {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult}</li>
* <li><strong>IF</strong> the method's single parameter's type is {@link InputGuardrailRequest} then the return type must
* be
* {@link InputGuardrailResult}</li>
* <li><strong>IF</strong> the method's single parameter's type is {@link OutputGuardrailParams} then the return type must
* <li><strong>IF</strong> the method's single parameter's type is {@link OutputGuardrailRequest} then the return type must
* be {@link OutputGuardrailResult}</li>
* </ul>
*/
Expand All @@ -143,7 +201,8 @@ static boolean isGuardrailValidateMethodWithParams(MethodInfo methodInfo) {
* Checks the method meets <strong>ALL</strong> the following conditions:
* <ul>
* <li>The method's name is {@link #VALIDATE_METHOD_NAME}</li>
* <li>The method's return type is {@link InputGuardrailResult}</li>
* <li>The method's return type is {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailResult} or
* {@link InputGuardrailResult}</li>
* <li>The method's single parameter's type is {@link dev.langchain4j.data.message.UserMessage}</li>
* </ul>
*/
Expand All @@ -156,7 +215,8 @@ private static boolean isInputGuardrailValidateMethodWithUserMessage(MethodInfo
* Checks the method meets <strong>ALL</strong> the following conditions:
* <ul>
* <li>The method's name is {@link #VALIDATE_METHOD_NAME}</li>
* <li>The method's return type is {@link OutputGuardrailResult}</li>
* <li>The method's return type is {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult} or
* {@link OutputGuardrailResult}</li>
* <li>The method's single parameter's type is {@link dev.langchain4j.data.message.AiMessage}</li>
* </ul>
*/
Expand All @@ -168,9 +228,16 @@ private static boolean isOutputGuardrailValidateMethodWithAiMessage(MethodInfo m
/**
* Checks the method meets <strong>ALL</strong> the following conditions:
* <ul>
* <li><strong>IF</strong> the method's single parameter's type is {@link InputGuardrailParams} then the return type must be
* <li><strong>IF</strong> the method's single parameter's type is
* {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailParams} then the return type must be
* {@link io.quarkiverse.langchain4j.guardrails.InputGuardrailResult}</li>
* <li><strong>IF</strong> the method's single parameter's type is
* {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams} then the return type must
* be {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult}</li>
* <li><strong>IF</strong> the method's single parameter's type is {@link InputGuardrailRequest} then the return type must
* be
* {@link InputGuardrailResult}</li>
* <li><strong>IF</strong> the method's single parameter's type is {@link OutputGuardrailParams} then the return type must
* <li><strong>IF</strong> the method's single parameter's type is {@link OutputGuardrailRequest} then the return type must
* be {@link OutputGuardrailResult}</li>
* </ul>
*/
Expand All @@ -187,8 +254,13 @@ private static boolean doesValidateMethodWithParamsHaveCorrectSignature(MethodIn
// Also check the return type
var returnType = methodInfo.returnType().name();

return (INPUT_GUARDRAIL_PARAMS.equals(paramTypeName) && INPUT_GUARDRAIL_RESULT.equals(returnType)) ||
(OUTPUT_GUARDRAIL_PARAMS.equals(paramTypeName) && OUTPUT_GUARDRAIL_RESULT.equals(returnType));
return (QUARKUS_INPUT_GUARDRAIL_PARAMS.equals(paramTypeName) && QUARKUS_INPUT_GUARDRAIL_RESULT.equals(returnType))
||
(QUARKUS_OUTPUT_GUARDRAIL_PARAMS.equals(paramTypeName)
&& QUARKUS_OUTPUT_GUARDRAIL_RESULT.equals(returnType))
||
(INPUT_GUARDRAIL_REQUEST.equals(paramTypeName) && INPUT_GUARDRAIL_RESULT.equals(returnType)) ||
(OUTPUT_GUARDRAIL_REQUEST.equals(paramTypeName) && OUTPUT_GUARDRAIL_RESULT.equals(returnType));
}

return false;
Expand All @@ -207,7 +279,8 @@ private static boolean doesValidateMethodWithoutParamsHaveCorrectSignature(Metho
var returnType = methodInfo.returnType().name();

return paramType.equals(paramTypeName) &&
(INPUT_GUARDRAIL_RESULT.equals(returnType) || OUTPUT_GUARDRAIL_RESULT.equals(returnType));
(QUARKUS_INPUT_GUARDRAIL_RESULT.equals(returnType) || QUARKUS_OUTPUT_GUARDRAIL_RESULT.equals(returnType) ||
INPUT_GUARDRAIL_RESULT.equals(returnType) || OUTPUT_GUARDRAIL_RESULT.equals(returnType));
}

return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.UserName;
import dev.langchain4j.service.guardrail.InputGuardrails;
import dev.langchain4j.service.guardrail.OutputGuardrails;
import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.web.search.WebSearchEngine;
import dev.langchain4j.web.search.WebSearchTool;
Expand All @@ -36,8 +38,6 @@
import io.quarkiverse.langchain4j.PdfUrl;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.SeedMemory;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContextQualifier;

public class LangChain4jDotNames {
Expand All @@ -49,6 +49,18 @@ public class LangChain4jDotNames {
public static final DotName IMAGE_MODEL = DotName.createSimple(ImageModel.class);
public static final DotName CHAT_MESSAGE = DotName.createSimple(ChatMessage.class);
public static final DotName TOKEN_STREAM = DotName.createSimple(TokenStream.class);
/**
* @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed
*/
@Deprecated(forRemoval = true)
public static final DotName QUARKUS_OUTPUT_GUARDRAILS = DotName
.createSimple(io.quarkiverse.langchain4j.guardrails.OutputGuardrails.class);
/**
* @deprecated Will go away once the Quarkus-specific guardrail implementation has been fully removed
*/
@Deprecated(forRemoval = true)
public static final DotName QUARKUS_INPUT_GUARDRAILS = DotName
.createSimple(io.quarkiverse.langchain4j.guardrails.InputGuardrails.class);
public static final DotName OUTPUT_GUARDRAILS = DotName.createSimple(OutputGuardrails.class);
public static final DotName INPUT_GUARDRAILS = DotName.createSimple(InputGuardrails.class);
static final DotName AI_SERVICES = DotName.createSimple(AiServices.class);
Expand Down
Loading
Loading