Skip to content

Commit bb7824c

Browse files
committed
feat: add a new genai locator module
1 parent d93426c commit bb7824c

File tree

14 files changed

+586
-0
lines changed

14 files changed

+586
-0
lines changed

build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ buildscript {
22
ext {
33
jmockitVersion = "1.49"
44
springBootVersion = "3.4.3"
5+
springAiVersion = "1.0.0"
56
}
67
}
78

java-cfenv-boot-tanzu-genai/README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Tanzu GenAI support
2+
3+
This library is for use when accessing a Tanzu GenAI tile (version >= 10.2) configured plan with CF. This library uses the `VCAP_SERVICES` environment data to set properties that will enable a GenAILocator.
4+
5+
## Spring Applications
6+
7+
Spring Applications can use this library to auto-configure a GenAILocator that can be used to determine which models/mcp servers are available, what capabilities they support and a method of accessing them.
8+
9+
This service provides the following properties to your spring application:
10+
11+
| Property Name | Value |
12+
|--------------------------|---------------------------------|
13+
| genai.locator.config-url | config_url (from VCAP_SERVICES) |
14+
| genai.locator.api-base | api_base (from VCAP_SERVICES) |
15+
| genai.locator.api-key | api_key (from VCAP_SERVICES) |
16+
17+
Please see the Sample Apps below for more information.
18+
19+
### Sample Apps
20+
21+
Sample apps using this library are available at TODO.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
plugins {
2+
id 'io.pivotal.cfenv.java-conventions'
3+
}
4+
5+
description = 'Java CF Env Tanzu GenAI'
6+
7+
dependencies {
8+
api project(':java-cfenv-boot')
9+
api "org.springframework.ai:spring-ai-openai:${springAiVersion}"
10+
implementation("org.springframework.boot:spring-boot-autoconfigure:${springBootVersion}")
11+
12+
testImplementation project(':java-cfenv-test-support')
13+
testImplementation "junit:junit"
14+
testImplementation "org.jmockit:jmockit:${jmockitVersion}"
15+
16+
testRuntimeOnly('org.junit.vintage:junit-vintage-engine') {
17+
exclude group: 'org.hamcrest', module: 'hamcrest-core'
18+
}
19+
}
20+
21+
tasks.named('jar') {
22+
manifest {
23+
attributes 'Automatic-Module-Name': 'io.pivotal.cfenv.boot.genai'
24+
}
25+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package io.pivotal.cfenv.boot.genai;
2+
3+
import io.pivotal.cfenv.core.CfCredentials;
4+
import io.pivotal.cfenv.core.CfService;
5+
import io.pivotal.cfenv.spring.boot.CfEnvProcessor;
6+
import io.pivotal.cfenv.spring.boot.CfEnvProcessorProperties;
7+
8+
import java.util.Map;
9+
10+
public class CfGenaiProcessor implements CfEnvProcessor {
11+
12+
@Override
13+
public boolean accept(CfService service) {
14+
boolean isGenAIService = service.existsByTagIgnoreCase("genai") || service.existsByLabelStartsWith("genai");
15+
// we only want to process service instances that are generated from Tanzu Platform 10.2 or later
16+
return (isGenAIService && service.getCredentials().getMap().containsKey("endpoint"));
17+
}
18+
19+
@Override
20+
public void process(CfCredentials cfCredentials, Map<String, Object> properties) {
21+
Map<String, Object> endpoint = (Map<String, Object>)cfCredentials.getMap().get("endpoint");
22+
23+
properties.put("genai.locator.config-url", endpoint.get("config_url"));
24+
properties.put("genai.locator.api-key", endpoint.get("api_key"));
25+
properties.put("genai.locator.api-base", endpoint.get( "api_base"));
26+
}
27+
28+
@Override
29+
public CfEnvProcessorProperties getProperties() {
30+
return CfEnvProcessorProperties.builder()
31+
.propertyPrefixes("genai.locator")
32+
.serviceName("Tanzu GenAI Locator").build();
33+
}
34+
}
Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
package io.pivotal.cfenv.boot.genai;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import org.slf4j.Logger;
6+
import org.slf4j.LoggerFactory;
7+
import org.springframework.ai.chat.model.ChatModel;
8+
import org.springframework.ai.document.MetadataMode;
9+
import org.springframework.ai.embedding.EmbeddingModel;
10+
import org.springframework.ai.openai.OpenAiChatModel;
11+
import org.springframework.ai.openai.OpenAiChatOptions;
12+
import org.springframework.ai.openai.OpenAiEmbeddingModel;
13+
import org.springframework.ai.openai.OpenAiEmbeddingOptions;
14+
import org.springframework.ai.openai.api.OpenAiApi;
15+
import org.springframework.web.client.RestClient;
16+
17+
import java.util.AbstractMap;
18+
import java.util.List;
19+
import java.util.Map;
20+
import java.util.function.Predicate;
21+
22+
public class DefaultGenAILocator implements GenAILocator {
23+
24+
private static final Logger LOGGER = LoggerFactory.getLogger(DefaultGenAILocator.class);
25+
26+
private final String configUrl;
27+
private final String apiKey;
28+
private final String apiBase;
29+
30+
public DefaultGenAILocator(String configUrl, String apiKey, String apiBase) {
31+
this.configUrl = configUrl;
32+
this.apiKey = apiKey;
33+
this.apiBase = apiBase;
34+
}
35+
36+
@Override
37+
public List<String> getModelNames() {
38+
return getModelNamesByCapability(null);
39+
}
40+
41+
@Override
42+
public List<String> getModelNamesByCapability(String capability) {
43+
return getModelNamesByCapabilityAndLabels(capability, Map.of());
44+
}
45+
46+
@Override
47+
public List<String> getModelNamesByLabels(Map<String, String> labels) {
48+
return getModelNamesByCapabilityAndLabels(null, labels);
49+
}
50+
51+
@Override
52+
public List<String> getModelNamesByCapabilityAndLabels(
53+
String capability, Map<String, String> labels) {
54+
List<ModelConnectivity> models = getAllModelConnectivityDetails();
55+
56+
return models.stream()
57+
.filter(filterModelConnectivityOnCapability(capability))
58+
.filter(filterModelConnectivityOnLabels(labels))
59+
.map(a -> a.name)
60+
.toList();
61+
}
62+
63+
@Override
64+
public ChatModel getChatModelByName(String name) {
65+
List<ModelConnectivity> models = getAllModelConnectivityDetails();
66+
67+
return models.stream()
68+
.filter(filterModelConnectivityOnCapability("CHAT"))
69+
.filter(c -> c.name().equals(name))
70+
.map(DefaultGenAILocator::createChatModel)
71+
.findFirst()
72+
.orElseThrow(
73+
() -> new RuntimeException("Unable to find chat model with name '" + name + "'"));
74+
}
75+
76+
@Override
77+
public List<ChatModel> getChatModelsByLabels(Map<String, String> labels) {
78+
List<ModelConnectivity> models = getAllModelConnectivityDetails();
79+
80+
return models.stream()
81+
.filter(filterModelConnectivityOnCapability("CHAT"))
82+
.filter(filterModelConnectivityOnLabels(labels))
83+
.map(DefaultGenAILocator::createChatModel)
84+
.toList();
85+
}
86+
87+
@Override
88+
public ChatModel getFirstAvailableChatModel() {
89+
List<ModelConnectivity> models = getAllModelConnectivityDetails();
90+
91+
return models.stream()
92+
.filter(filterModelConnectivityOnCapability("CHAT"))
93+
.map(DefaultGenAILocator::createChatModel)
94+
.findFirst()
95+
.orElseThrow(() -> new RuntimeException("Unable to find first chat model"));
96+
}
97+
98+
@Override
99+
public ChatModel getFirstAvailableChatModelByLabels(Map<String, String> labels) {
100+
List<ModelConnectivity> models = getAllModelConnectivityDetails();
101+
102+
return models.stream()
103+
.filter(filterModelConnectivityOnCapability("CHAT"))
104+
.filter(filterModelConnectivityOnLabels(labels))
105+
.map(DefaultGenAILocator::createChatModel)
106+
.findFirst()
107+
.orElseThrow(() -> new RuntimeException("Unable to find first chat model"));
108+
}
109+
110+
@Override
111+
public List<ChatModel> getToolModelsByLabels(Map<String, String> labels) {
112+
List<ModelConnectivity> models = getAllModelConnectivityDetails();
113+
114+
return models.stream()
115+
.filter(filterModelConnectivityOnCapability("TOOLS"))
116+
.filter(filterModelConnectivityOnLabels(labels))
117+
.map(DefaultGenAILocator::createChatModel)
118+
.toList();
119+
}
120+
121+
@Override
122+
public ChatModel getFirstAvailableToolModel() {
123+
List<ModelConnectivity> models = getAllModelConnectivityDetails();
124+
125+
return models.stream()
126+
.filter(filterModelConnectivityOnCapability("TOOLS"))
127+
.map(DefaultGenAILocator::createChatModel)
128+
.findFirst()
129+
.orElseThrow(() -> new RuntimeException("Unable to find first tool model"));
130+
}
131+
132+
@Override
133+
public ChatModel getFirstAvailableToolModelByLabels(Map<String, String> labels) {
134+
List<ModelConnectivity> models = getAllModelConnectivityDetails();
135+
136+
return models.stream()
137+
.filter(filterModelConnectivityOnCapability("TOOLS"))
138+
.filter(filterModelConnectivityOnLabels(labels))
139+
.map(DefaultGenAILocator::createChatModel)
140+
.findFirst()
141+
.orElseThrow(() -> new RuntimeException("Unable to find first tool model"));
142+
}
143+
144+
private static ChatModel createChatModel(ModelConnectivity c) {
145+
OpenAiApi api = OpenAiApi.builder().apiKey(c.apiKey()).baseUrl(c.apiBase()).build();
146+
return OpenAiChatModel.builder()
147+
.defaultOptions(OpenAiChatOptions.builder().model(c.name()).build())
148+
.openAiApi(api)
149+
.build();
150+
}
151+
152+
@Override
153+
public EmbeddingModel getEmbeddingModelByName(String name) {
154+
List<ModelConnectivity> models = getAllModelConnectivityDetails();
155+
156+
return models.stream()
157+
.filter(filterModelConnectivityOnCapability("EMBEDDING"))
158+
.filter(c -> c.name().equals(name))
159+
.map(DefaultGenAILocator::createEmbeddingModel)
160+
.findFirst()
161+
.orElseThrow(
162+
() -> new RuntimeException("Unable to find embedding model with name '" + name + "'"));
163+
}
164+
165+
@Override
166+
public List<EmbeddingModel> getEmbeddingModelsByLabels(Map<String, String> labels) {
167+
List<ModelConnectivity> models = getAllModelConnectivityDetails();
168+
169+
return models.stream()
170+
.filter(filterModelConnectivityOnCapability("EMBEDDING"))
171+
.filter(filterModelConnectivityOnLabels(labels))
172+
.map(DefaultGenAILocator::createEmbeddingModel)
173+
.toList();
174+
}
175+
176+
@Override
177+
public EmbeddingModel getFirstAvailableEmbeddingModel() {
178+
List<ModelConnectivity> models = getAllModelConnectivityDetails();
179+
180+
return models.stream()
181+
.filter(filterModelConnectivityOnCapability("EMBEDDING"))
182+
.map(DefaultGenAILocator::createEmbeddingModel)
183+
.findFirst()
184+
.orElseThrow(() -> new RuntimeException("Unable to find first embedding model"));
185+
}
186+
187+
@Override
188+
public EmbeddingModel getFirstAvailableEmbeddingModelByLabels(Map<String, String> labels) {
189+
List<ModelConnectivity> models = getAllModelConnectivityDetails();
190+
191+
return models.stream()
192+
.filter(filterModelConnectivityOnCapability("EMBEDDING"))
193+
.filter(filterModelConnectivityOnLabels(labels))
194+
.map(DefaultGenAILocator::createEmbeddingModel)
195+
.findFirst()
196+
.orElseThrow(() -> new RuntimeException("Unable to find first embedding model"));
197+
}
198+
199+
private static EmbeddingModel createEmbeddingModel(ModelConnectivity c) {
200+
OpenAiApi api = OpenAiApi.builder().apiKey(c.apiKey()).baseUrl(c.apiBase()).build();
201+
return new OpenAiEmbeddingModel(
202+
api, MetadataMode.EMBED, OpenAiEmbeddingOptions.builder().model(c.name()).build());
203+
}
204+
205+
@Override
206+
public List<McpConnectivity> getMcpServers() {
207+
return getAllMcpConnectivityDetails();
208+
}
209+
210+
private List<ModelConnectivity> getAllModelConnectivityDetails() {
211+
ConfigEndpoint e = getEndpointConfig();
212+
return e.advertisedModels
213+
.stream()
214+
.map( a ->
215+
new ModelConnectivity(
216+
a.name(),
217+
a.capabilities(),
218+
a.labels(),
219+
apiKey,
220+
apiBase + e.wireFormat().toLowerCase())
221+
)
222+
.toList();
223+
}
224+
225+
private List<McpConnectivity> getAllMcpConnectivityDetails() {
226+
return getEndpointConfig()
227+
.advertisedMcpServers
228+
.stream()
229+
.map(m -> new McpConnectivity(m.url()))
230+
.toList();
231+
}
232+
233+
private ConfigEndpoint getEndpointConfig() {
234+
RestClient client = RestClient.builder().build();
235+
return client
236+
.get()
237+
.uri(configUrl)
238+
.header("Authorization", "Bearer " + apiKey)
239+
.retrieve()
240+
.body(ConfigEndpoint.class);
241+
}
242+
243+
private Predicate<ModelConnectivity> filterModelConnectivityOnLabels(Map<String, String> labels) {
244+
return modelConnectivity -> {
245+
if (labels == null || labels.isEmpty()) {
246+
return true;
247+
}
248+
249+
if (modelConnectivity.labels() == null) {
250+
return false;
251+
}
252+
253+
return modelConnectivity.labels().entrySet().containsAll(labels.entrySet());
254+
};
255+
}
256+
257+
private Predicate<ModelConnectivity> filterModelConnectivityOnCapability(String capability) {
258+
return modelConnectivity -> {
259+
if (capability == null || capability.isEmpty()) {
260+
return true;
261+
}
262+
263+
if (modelConnectivity.capabilities() == null) {
264+
return false;
265+
}
266+
267+
return modelConnectivity.capabilities().contains(capability);
268+
};
269+
}
270+
271+
@JsonIgnoreProperties(ignoreUnknown = true)
272+
private record ModelConnectivity(
273+
String name,
274+
List<String> capabilities,
275+
Map<String, String> labels,
276+
String apiKey,
277+
String apiBase) {}
278+
279+
@JsonIgnoreProperties(ignoreUnknown = true)
280+
private record ConfigEndpoint(
281+
@JsonProperty("name") String name,
282+
@JsonProperty("description") String description,
283+
@JsonProperty("wireFormat") String wireFormat,
284+
@JsonProperty("advertisedModels") List<ConfigAdvertisedModel> advertisedModels,
285+
@JsonProperty("advertisedMcpServers") List<ConfigAdvertisedMcpServer> advertisedMcpServers) {}
286+
287+
@JsonIgnoreProperties(ignoreUnknown = true)
288+
private record ConfigAdvertisedModel(
289+
@JsonProperty("name") String name,
290+
@JsonProperty("description") String description,
291+
@JsonProperty("capabilities") List<String> capabilities,
292+
@JsonProperty("labels") Map<String, String> labels) {}
293+
294+
@JsonIgnoreProperties(ignoreUnknown = true)
295+
private record ConfigAdvertisedMcpServer(@JsonProperty("url") String url) {}
296+
}

0 commit comments

Comments
 (0)