Skip to content

Commit 683a85b

Browse files
Langchain::Assistant works with AWS Bedrock-hosted Anthropic models (#849)
* Langchain::Assistant works with AWS Bedrock-hosted Anthropic models * specs * Update adapter.rb * Fixes * changelog entry
1 parent 7baf643 commit 683a85b

File tree

9 files changed

+164
-235
lines changed

9 files changed

+164
-235
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
- [BREAKING]: A breaking change. After an upgrade, your app may need modifications to keep working correctly.
55
- [FEATURE]: A non-breaking improvement to the app. Either introduces new functionality, or improves on an existing feature.
66
- [BUGFIX]: Fixes a bug with a non-breaking change.
7-
- [COMPAT]: Compatibility improvements - changes to make Administrate more compatible with different dependency versions.
7+
- [COMPAT]: Compatibility improvements - changes to make Langchain.rb more compatible with different dependency versions.
88
- [OPTIM]: Optimization or performance increase.
99
- [DOCS]: Documentation changes. No changes to the library's behavior.
1010
- [SECURITY]: A change which fixes a security vulnerability.
1111

1212
## [Unreleased]
1313
- [FEATURE] [https://github.com/patterns-ai-core/langchainrb/pull/858] Assistant, when using Anthropic, now also accepts image_url in the message.
1414
- [FEATURE] [https://github.com/patterns-ai-core/langchainrb/pull/861] Clean up passing `max_tokens` to Anthropic constructor and chat method
15+
- [FEATURE] [https://github.com/patterns-ai-core/langchainrb/pull/849] Langchain::Assistant now works with AWS Bedrock-hosted Anthropic models
16+
- [OPTIM] [https://github.com/patterns-ai-core/langchainrb/pull/849] Simplify Langchain::LLM::AwsBedrock class
1517

1618
## [0.19.0] - 2024-10-23
1719
- [BREAKING] [https://github.com/patterns-ai-core/langchainrb/pull/840] Rename `chat_completion_model_name` parameter to `chat_model` in Langchain::LLM parameters.

Gemfile.lock

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,16 @@ GEM
4848
faraday-multipart (>= 1)
4949
ast (2.4.2)
5050
aws-eventstream (1.3.0)
51-
aws-partitions (1.937.0)
52-
aws-sdk-bedrockruntime (1.9.0)
53-
aws-sdk-core (~> 3, >= 3.193.0)
54-
aws-sigv4 (~> 1.1)
55-
aws-sdk-core (3.196.1)
51+
aws-partitions (1.992.0)
52+
aws-sdk-bedrockruntime (1.27.0)
53+
aws-sdk-core (~> 3, >= 3.210.0)
54+
aws-sigv4 (~> 1.5)
55+
aws-sdk-core (3.210.0)
5656
aws-eventstream (~> 1, >= 1.3.0)
57-
aws-partitions (~> 1, >= 1.651.0)
58-
aws-sigv4 (~> 1.8)
57+
aws-partitions (~> 1, >= 1.992.0)
58+
aws-sigv4 (~> 1.9)
5959
jmespath (~> 1, >= 1.6.1)
60-
aws-sigv4 (1.8.0)
60+
aws-sigv4 (1.10.0)
6161
aws-eventstream (~> 1, >= 1.0.2)
6262
baran (0.1.12)
6363
base64 (0.2.0)

lib/langchain/assistant/llm/adapter.rb

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@ module LLM
66
# TODO: Fix the message truncation when context window is exceeded
77
class Adapter
88
def self.build(llm)
9-
case llm
10-
when Langchain::LLM::Anthropic
9+
if llm.is_a?(Langchain::LLM::Anthropic)
1110
LLM::Adapters::Anthropic.new
12-
when Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI
11+
elsif llm.is_a?(Langchain::LLM::AwsBedrock) && llm.defaults[:chat_model].include?("anthropic")
12+
LLM::Adapters::AwsBedrockAnthropic.new
13+
elsif llm.is_a?(Langchain::LLM::GoogleGemini) || llm.is_a?(Langchain::LLM::GoogleVertexAI)
1314
LLM::Adapters::GoogleGemini.new
14-
when Langchain::LLM::MistralAI
15+
elsif llm.is_a?(Langchain::LLM::MistralAI)
1516
LLM::Adapters::MistralAI.new
16-
when Langchain::LLM::Ollama
17+
elsif llm.is_a?(Langchain::LLM::Ollama)
1718
LLM::Adapters::Ollama.new
18-
when Langchain::LLM::OpenAI
19+
elsif llm.is_a?(Langchain::LLM::OpenAI)
1920
LLM::Adapters::OpenAI.new
2021
else
2122
raise ArgumentError, "Unsupported LLM type: #{llm.class}"
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# frozen_string_literal: true
2+
3+
module Langchain
4+
class Assistant
5+
module LLM
6+
module Adapters
7+
class AwsBedrockAnthropic < Anthropic
8+
private
9+
10+
# @param [String] choice
11+
# @param [Boolean] _parallel_tool_calls
12+
# @return [Hash]
13+
def build_tool_choice(choice, _parallel_tool_calls)
14+
# AWS Bedrock-hosted Anthropic does not support parallel tool calls
15+
Langchain.logger.warn "WARNING: parallel_tool_calls is not supported by AWS Bedrock Anthropic currently"
16+
17+
tool_choice_object = {}
18+
19+
case choice
20+
when "auto"
21+
tool_choice_object[:type] = "auto"
22+
when "any"
23+
tool_choice_object[:type] = "any"
24+
else
25+
tool_choice_object[:type] = "tool"
26+
tool_choice_object[:name] = choice
27+
end
28+
29+
tool_choice_object
30+
end
31+
end
32+
end
33+
end
34+
end
35+
end

lib/langchain/llm/aws_bedrock.rb

Lines changed: 69 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -7,51 +7,40 @@ module Langchain::LLM
77
# gem 'aws-sdk-bedrockruntime', '~> 1.1'
88
#
99
# Usage:
10-
# llm = Langchain::LLM::AwsBedrock.new(llm_options: {})
10+
# llm = Langchain::LLM::AwsBedrock.new(default_options: {})
1111
#
1212
class AwsBedrock < Base
1313
DEFAULTS = {
14-
chat_model: "anthropic.claude-v2",
15-
completion_model: "anthropic.claude-v2",
14+
chat_model: "anthropic.claude-3-5-sonnet-20240620-v1:0",
15+
completion_model: "anthropic.claude-v2:1",
1616
embedding_model: "amazon.titan-embed-text-v1",
1717
max_tokens_to_sample: 300,
1818
temperature: 1,
1919
top_k: 250,
2020
top_p: 0.999,
2121
stop_sequences: ["\n\nHuman:"],
22-
anthropic_version: "bedrock-2023-05-31",
23-
return_likelihoods: "NONE",
24-
count_penalty: {
25-
scale: 0,
26-
apply_to_whitespaces: false,
27-
apply_to_punctuations: false,
28-
apply_to_numbers: false,
29-
apply_to_stopwords: false,
30-
apply_to_emojis: false
31-
},
32-
presence_penalty: {
33-
scale: 0,
34-
apply_to_whitespaces: false,
35-
apply_to_punctuations: false,
36-
apply_to_numbers: false,
37-
apply_to_stopwords: false,
38-
apply_to_emojis: false
39-
},
40-
frequency_penalty: {
41-
scale: 0,
42-
apply_to_whitespaces: false,
43-
apply_to_punctuations: false,
44-
apply_to_numbers: false,
45-
apply_to_stopwords: false,
46-
apply_to_emojis: false
47-
}
22+
return_likelihoods: "NONE"
4823
}.freeze
4924

5025
attr_reader :client, :defaults
5126

52-
SUPPORTED_COMPLETION_PROVIDERS = %i[anthropic ai21 cohere meta].freeze
53-
SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[anthropic].freeze
54-
SUPPORTED_EMBEDDING_PROVIDERS = %i[amazon cohere].freeze
27+
SUPPORTED_COMPLETION_PROVIDERS = %i[
28+
anthropic
29+
ai21
30+
cohere
31+
meta
32+
].freeze
33+
34+
SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[
35+
anthropic
36+
ai21
37+
mistral
38+
].freeze
39+
40+
SUPPORTED_EMBEDDING_PROVIDERS = %i[
41+
amazon
42+
cohere
43+
].freeze
5544

5645
def initialize(aws_client_options: {}, default_options: {})
5746
depends_on "aws-sdk-bedrockruntime", req: "aws-sdk-bedrockruntime"
@@ -64,8 +53,7 @@ def initialize(aws_client_options: {}, default_options: {})
6453
temperature: {},
6554
max_tokens: {default: @defaults[:max_tokens_to_sample]},
6655
metadata: {},
67-
system: {},
68-
anthropic_version: {default: "bedrock-2023-05-31"}
56+
system: {}
6957
)
7058
chat_parameters.ignore(:n, :user)
7159
chat_parameters.remap(stop: :stop_sequences)
@@ -100,23 +88,25 @@ def embed(text:, **params)
10088
# @param params extra parameters passed to Aws::BedrockRuntime::Client#invoke_model
10189
# @return [Langchain::LLM::AnthropicResponse, Langchain::LLM::CohereResponse, Langchain::LLM::AI21Response, Langchain::LLM::AwsBedrockMetaResponse] Response object
10290
#
103-
def complete(prompt:, **params)
104-
raise "Completion provider #{completion_provider} is not supported." unless SUPPORTED_COMPLETION_PROVIDERS.include?(completion_provider)
91+
def complete(
92+
prompt:,
93+
model: @defaults[:completion_model],
94+
**params
95+
)
96+
raise "Completion provider #{model} is not supported." unless SUPPORTED_COMPLETION_PROVIDERS.include?(provider_name(model))
10597

106-
raise "Model #{@defaults[:completion_model]} only supports #chat." if @defaults[:completion_model].include?("claude-3")
107-
108-
parameters = compose_parameters params
98+
parameters = compose_parameters(params, model)
10999

110100
parameters[:prompt] = wrap_prompt prompt
111101

112102
response = client.invoke_model({
113-
model_id: @defaults[:completion_model],
103+
model_id: model,
114104
body: parameters.to_json,
115105
content_type: "application/json",
116106
accept: "application/json"
117107
})
118108

119-
parse_response response
109+
parse_response(response, model)
120110
end
121111

122112
# Generate a chat completion for a given prompt
@@ -137,10 +127,11 @@ def complete(prompt:, **params)
137127
# @return [Langchain::LLM::AnthropicResponse] Response object
138128
def chat(params = {}, &block)
139129
parameters = chat_parameters.to_params(params)
130+
parameters = compose_parameters(parameters, parameters[:model])
140131

141-
raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?
142-
143-
raise "Model #{parameters[:model]} does not support chat completions." unless Langchain::LLM::AwsBedrock::SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(completion_provider)
132+
unless SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(provider_name(parameters[:model]))
133+
raise "Chat provider #{parameters[:model]} is not supported."
134+
end
144135

145136
if block
146137
response_chunks = []
@@ -168,12 +159,26 @@ def chat(params = {}, &block)
168159
accept: "application/json"
169160
})
170161

171-
parse_response response
162+
parse_response(response, parameters[:model])
172163
end
173164
end
174165

175166
private
176167

168+
def parse_model_id(model_id)
169+
model_id
170+
.gsub("us.", "") # Meta append "us." to their model ids
171+
.split(".")
172+
end
173+
174+
def provider_name(model_id)
175+
parse_model_id(model_id).first.to_sym
176+
end
177+
178+
def model_name(model_id)
179+
parse_model_id(model_id).last
180+
end
181+
177182
def completion_provider
178183
@defaults[:completion_model].split(".").first.to_sym
179184
end
@@ -200,15 +205,17 @@ def max_tokens_key
200205
end
201206
end
202207

203-
def compose_parameters(params)
204-
if completion_provider == :anthropic
205-
compose_parameters_anthropic params
206-
elsif completion_provider == :cohere
207-
compose_parameters_cohere params
208-
elsif completion_provider == :ai21
209-
compose_parameters_ai21 params
210-
elsif completion_provider == :meta
211-
compose_parameters_meta params
208+
def compose_parameters(params, model_id)
209+
if provider_name(model_id) == :anthropic
210+
compose_parameters_anthropic(params)
211+
elsif provider_name(model_id) == :cohere
212+
compose_parameters_cohere(params)
213+
elsif provider_name(model_id) == :ai21
214+
params
215+
elsif provider_name(model_id) == :meta
216+
params
217+
elsif provider_name(model_id) == :mistral
218+
params
212219
end
213220
end
214221

@@ -220,15 +227,17 @@ def compose_embedding_parameters(params)
220227
end
221228
end
222229

223-
def parse_response(response)
224-
if completion_provider == :anthropic
230+
def parse_response(response, model_id)
231+
if provider_name(model_id) == :anthropic
225232
Langchain::LLM::AnthropicResponse.new(JSON.parse(response.body.string))
226-
elsif completion_provider == :cohere
233+
elsif provider_name(model_id) == :cohere
227234
Langchain::LLM::CohereResponse.new(JSON.parse(response.body.string))
228-
elsif completion_provider == :ai21
235+
elsif provider_name(model_id) == :ai21
229236
Langchain::LLM::AI21Response.new(JSON.parse(response.body.string, symbolize_names: true))
230-
elsif completion_provider == :meta
237+
elsif provider_name(model_id) == :meta
231238
Langchain::LLM::AwsBedrockMetaResponse.new(JSON.parse(response.body.string))
239+
elsif provider_name(model_id) == :mistral
240+
Langchain::LLM::MistralAIResponse.new(JSON.parse(response.body.string))
232241
end
233242
end
234243

@@ -276,61 +285,7 @@ def compose_parameters_cohere(params)
276285
end
277286

278287
def compose_parameters_anthropic(params)
279-
default_params = @defaults.merge(params)
280-
281-
{
282-
max_tokens_to_sample: default_params[:max_tokens_to_sample],
283-
temperature: default_params[:temperature],
284-
top_k: default_params[:top_k],
285-
top_p: default_params[:top_p],
286-
stop_sequences: default_params[:stop_sequences],
287-
anthropic_version: default_params[:anthropic_version]
288-
}
289-
end
290-
291-
def compose_parameters_ai21(params)
292-
default_params = @defaults.merge(params)
293-
294-
{
295-
maxTokens: default_params[:max_tokens_to_sample],
296-
temperature: default_params[:temperature],
297-
topP: default_params[:top_p],
298-
stopSequences: default_params[:stop_sequences],
299-
countPenalty: {
300-
scale: default_params[:count_penalty][:scale],
301-
applyToWhitespaces: default_params[:count_penalty][:apply_to_whitespaces],
302-
applyToPunctuations: default_params[:count_penalty][:apply_to_punctuations],
303-
applyToNumbers: default_params[:count_penalty][:apply_to_numbers],
304-
applyToStopwords: default_params[:count_penalty][:apply_to_stopwords],
305-
applyToEmojis: default_params[:count_penalty][:apply_to_emojis]
306-
},
307-
presencePenalty: {
308-
scale: default_params[:presence_penalty][:scale],
309-
applyToWhitespaces: default_params[:presence_penalty][:apply_to_whitespaces],
310-
applyToPunctuations: default_params[:presence_penalty][:apply_to_punctuations],
311-
applyToNumbers: default_params[:presence_penalty][:apply_to_numbers],
312-
applyToStopwords: default_params[:presence_penalty][:apply_to_stopwords],
313-
applyToEmojis: default_params[:presence_penalty][:apply_to_emojis]
314-
},
315-
frequencyPenalty: {
316-
scale: default_params[:frequency_penalty][:scale],
317-
applyToWhitespaces: default_params[:frequency_penalty][:apply_to_whitespaces],
318-
applyToPunctuations: default_params[:frequency_penalty][:apply_to_punctuations],
319-
applyToNumbers: default_params[:frequency_penalty][:apply_to_numbers],
320-
applyToStopwords: default_params[:frequency_penalty][:apply_to_stopwords],
321-
applyToEmojis: default_params[:frequency_penalty][:apply_to_emojis]
322-
}
323-
}
324-
end
325-
326-
def compose_parameters_meta(params)
327-
default_params = @defaults.merge(params)
328-
329-
{
330-
temperature: default_params[:temperature],
331-
top_p: default_params[:top_p],
332-
max_gen_len: default_params[:max_tokens_to_sample]
333-
}
288+
params.merge(anthropic_version: "bedrock-2023-05-31")
334289
end
335290

336291
def response_from_chunks(chunks)

0 commit comments

Comments
 (0)