@@ -7,51 +7,40 @@ module Langchain::LLM
7
7
# gem 'aws-sdk-bedrockruntime', '~> 1.1'
8
8
#
9
9
# Usage:
10
- # llm = Langchain::LLM::AwsBedrock.new(llm_options : {})
10
+ # llm = Langchain::LLM::AwsBedrock.new(default_options : {})
11
11
#
12
12
class AwsBedrock < Base
13
13
DEFAULTS = {
14
- chat_model : "anthropic.claude-v2 " ,
15
- completion_model : "anthropic.claude-v2" ,
14
+ chat_model : "anthropic.claude-3-5-sonnet-20240620-v1:0 " ,
15
+ completion_model : "anthropic.claude-v2:1 " ,
16
16
embedding_model : "amazon.titan-embed-text-v1" ,
17
17
max_tokens_to_sample : 300 ,
18
18
temperature : 1 ,
19
19
top_k : 250 ,
20
20
top_p : 0.999 ,
21
21
stop_sequences : [ "\n \n Human:" ] ,
22
- anthropic_version : "bedrock-2023-05-31" ,
23
- return_likelihoods : "NONE" ,
24
- count_penalty : {
25
- scale : 0 ,
26
- apply_to_whitespaces : false ,
27
- apply_to_punctuations : false ,
28
- apply_to_numbers : false ,
29
- apply_to_stopwords : false ,
30
- apply_to_emojis : false
31
- } ,
32
- presence_penalty : {
33
- scale : 0 ,
34
- apply_to_whitespaces : false ,
35
- apply_to_punctuations : false ,
36
- apply_to_numbers : false ,
37
- apply_to_stopwords : false ,
38
- apply_to_emojis : false
39
- } ,
40
- frequency_penalty : {
41
- scale : 0 ,
42
- apply_to_whitespaces : false ,
43
- apply_to_punctuations : false ,
44
- apply_to_numbers : false ,
45
- apply_to_stopwords : false ,
46
- apply_to_emojis : false
47
- }
22
+ return_likelihoods : "NONE"
48
23
} . freeze
49
24
50
25
attr_reader :client , :defaults
51
26
52
- SUPPORTED_COMPLETION_PROVIDERS = %i[ anthropic ai21 cohere meta ] . freeze
53
- SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[ anthropic ] . freeze
54
- SUPPORTED_EMBEDDING_PROVIDERS = %i[ amazon cohere ] . freeze
27
+ SUPPORTED_COMPLETION_PROVIDERS = %i[
28
+ anthropic
29
+ ai21
30
+ cohere
31
+ meta
32
+ ] . freeze
33
+
34
+ SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[
35
+ anthropic
36
+ ai21
37
+ mistral
38
+ ] . freeze
39
+
40
+ SUPPORTED_EMBEDDING_PROVIDERS = %i[
41
+ amazon
42
+ cohere
43
+ ] . freeze
55
44
56
45
def initialize ( aws_client_options : { } , default_options : { } )
57
46
depends_on "aws-sdk-bedrockruntime" , req : "aws-sdk-bedrockruntime"
@@ -64,8 +53,7 @@ def initialize(aws_client_options: {}, default_options: {})
64
53
temperature : { } ,
65
54
max_tokens : { default : @defaults [ :max_tokens_to_sample ] } ,
66
55
metadata : { } ,
67
- system : { } ,
68
- anthropic_version : { default : "bedrock-2023-05-31" }
56
+ system : { }
69
57
)
70
58
chat_parameters . ignore ( :n , :user )
71
59
chat_parameters . remap ( stop : :stop_sequences )
@@ -100,23 +88,25 @@ def embed(text:, **params)
100
88
# @param params extra parameters passed to Aws::BedrockRuntime::Client#invoke_model
101
89
# @return [Langchain::LLM::AnthropicResponse], [Langchain::LLM::CohereResponse] or [Langchain::LLM::AI21Response] Response object
102
90
#
103
- def complete ( prompt :, **params )
104
- raise "Completion provider #{ completion_provider } is not supported." unless SUPPORTED_COMPLETION_PROVIDERS . include? ( completion_provider )
91
+ def complete (
92
+ prompt :,
93
+ model : @defaults [ :completion_model ] ,
94
+ **params
95
+ )
96
+ raise "Completion provider #{ model } is not supported." unless SUPPORTED_COMPLETION_PROVIDERS . include? ( provider_name ( model ) )
105
97
106
- raise "Model #{ @defaults [ :completion_model ] } only supports #chat." if @defaults [ :completion_model ] . include? ( "claude-3" )
107
-
108
- parameters = compose_parameters params
98
+ parameters = compose_parameters ( params , model )
109
99
110
100
parameters [ :prompt ] = wrap_prompt prompt
111
101
112
102
response = client . invoke_model ( {
113
- model_id : @defaults [ :completion_model ] ,
103
+ model_id : model ,
114
104
body : parameters . to_json ,
115
105
content_type : "application/json" ,
116
106
accept : "application/json"
117
107
} )
118
108
119
- parse_response response
109
+ parse_response ( response , model )
120
110
end
121
111
122
112
# Generate a chat completion for a given prompt
@@ -137,10 +127,11 @@ def complete(prompt:, **params)
137
127
# @return [Langchain::LLM::AnthropicResponse] Response object
138
128
def chat ( params = { } , &block )
139
129
parameters = chat_parameters . to_params ( params )
130
+ parameters = compose_parameters ( parameters , parameters [ :model ] )
140
131
141
- raise ArgumentError . new ( "messages argument is required" ) if Array ( parameters [ :messages ] ) . empty?
142
-
143
- raise "Model #{ parameters [ :model ] } does not support chat completions." unless Langchain :: LLM :: AwsBedrock :: SUPPORTED_CHAT_COMPLETION_PROVIDERS . include? ( completion_provider )
132
+ unless SUPPORTED_CHAT_COMPLETION_PROVIDERS . include? ( provider_name ( parameters [ :model ] ) )
133
+ raise "Chat provider #{ parameters [ :model ] } is not supported."
134
+ end
144
135
145
136
if block
146
137
response_chunks = [ ]
@@ -168,12 +159,26 @@ def chat(params = {}, &block)
168
159
accept : "application/json"
169
160
} )
170
161
171
- parse_response response
162
+ parse_response ( response , parameters [ :model ] )
172
163
end
173
164
end
174
165
175
166
private
176
167
168
+ def parse_model_id ( model_id )
169
+ model_id
170
+ . gsub ( "us." , "" ) # Meta append "us." to their model ids
171
+ . split ( "." )
172
+ end
173
+
174
+ def provider_name ( model_id )
175
+ parse_model_id ( model_id ) . first . to_sym
176
+ end
177
+
178
+ def model_name ( model_id )
179
+ parse_model_id ( model_id ) . last
180
+ end
181
+
177
182
def completion_provider
178
183
@defaults [ :completion_model ] . split ( "." ) . first . to_sym
179
184
end
@@ -200,15 +205,17 @@ def max_tokens_key
200
205
end
201
206
end
202
207
203
- def compose_parameters ( params )
204
- if completion_provider == :anthropic
205
- compose_parameters_anthropic params
206
- elsif completion_provider == :cohere
207
- compose_parameters_cohere params
208
- elsif completion_provider == :ai21
209
- compose_parameters_ai21 params
210
- elsif completion_provider == :meta
211
- compose_parameters_meta params
208
+ def compose_parameters ( params , model_id )
209
+ if provider_name ( model_id ) == :anthropic
210
+ compose_parameters_anthropic ( params )
211
+ elsif provider_name ( model_id ) == :cohere
212
+ compose_parameters_cohere ( params )
213
+ elsif provider_name ( model_id ) == :ai21
214
+ params
215
+ elsif provider_name ( model_id ) == :meta
216
+ params
217
+ elsif provider_name ( model_id ) == :mistral
218
+ params
212
219
end
213
220
end
214
221
@@ -220,15 +227,17 @@ def compose_embedding_parameters(params)
220
227
end
221
228
end
222
229
223
- def parse_response ( response )
224
- if completion_provider == :anthropic
230
+ def parse_response ( response , model_id )
231
+ if provider_name ( model_id ) == :anthropic
225
232
Langchain ::LLM ::AnthropicResponse . new ( JSON . parse ( response . body . string ) )
226
- elsif completion_provider == :cohere
233
+ elsif provider_name ( model_id ) == :cohere
227
234
Langchain ::LLM ::CohereResponse . new ( JSON . parse ( response . body . string ) )
228
- elsif completion_provider == :ai21
235
+ elsif provider_name ( model_id ) == :ai21
229
236
Langchain ::LLM ::AI21Response . new ( JSON . parse ( response . body . string , symbolize_names : true ) )
230
- elsif completion_provider == :meta
237
+ elsif provider_name ( model_id ) == :meta
231
238
Langchain ::LLM ::AwsBedrockMetaResponse . new ( JSON . parse ( response . body . string ) )
239
+ elsif provider_name ( model_id ) == :mistral
240
+ Langchain ::LLM ::MistralAIResponse . new ( JSON . parse ( response . body . string ) )
232
241
end
233
242
end
234
243
@@ -276,61 +285,7 @@ def compose_parameters_cohere(params)
276
285
end
277
286
278
287
def compose_parameters_anthropic ( params )
279
- default_params = @defaults . merge ( params )
280
-
281
- {
282
- max_tokens_to_sample : default_params [ :max_tokens_to_sample ] ,
283
- temperature : default_params [ :temperature ] ,
284
- top_k : default_params [ :top_k ] ,
285
- top_p : default_params [ :top_p ] ,
286
- stop_sequences : default_params [ :stop_sequences ] ,
287
- anthropic_version : default_params [ :anthropic_version ]
288
- }
289
- end
290
-
291
- def compose_parameters_ai21 ( params )
292
- default_params = @defaults . merge ( params )
293
-
294
- {
295
- maxTokens : default_params [ :max_tokens_to_sample ] ,
296
- temperature : default_params [ :temperature ] ,
297
- topP : default_params [ :top_p ] ,
298
- stopSequences : default_params [ :stop_sequences ] ,
299
- countPenalty : {
300
- scale : default_params [ :count_penalty ] [ :scale ] ,
301
- applyToWhitespaces : default_params [ :count_penalty ] [ :apply_to_whitespaces ] ,
302
- applyToPunctuations : default_params [ :count_penalty ] [ :apply_to_punctuations ] ,
303
- applyToNumbers : default_params [ :count_penalty ] [ :apply_to_numbers ] ,
304
- applyToStopwords : default_params [ :count_penalty ] [ :apply_to_stopwords ] ,
305
- applyToEmojis : default_params [ :count_penalty ] [ :apply_to_emojis ]
306
- } ,
307
- presencePenalty : {
308
- scale : default_params [ :presence_penalty ] [ :scale ] ,
309
- applyToWhitespaces : default_params [ :presence_penalty ] [ :apply_to_whitespaces ] ,
310
- applyToPunctuations : default_params [ :presence_penalty ] [ :apply_to_punctuations ] ,
311
- applyToNumbers : default_params [ :presence_penalty ] [ :apply_to_numbers ] ,
312
- applyToStopwords : default_params [ :presence_penalty ] [ :apply_to_stopwords ] ,
313
- applyToEmojis : default_params [ :presence_penalty ] [ :apply_to_emojis ]
314
- } ,
315
- frequencyPenalty : {
316
- scale : default_params [ :frequency_penalty ] [ :scale ] ,
317
- applyToWhitespaces : default_params [ :frequency_penalty ] [ :apply_to_whitespaces ] ,
318
- applyToPunctuations : default_params [ :frequency_penalty ] [ :apply_to_punctuations ] ,
319
- applyToNumbers : default_params [ :frequency_penalty ] [ :apply_to_numbers ] ,
320
- applyToStopwords : default_params [ :frequency_penalty ] [ :apply_to_stopwords ] ,
321
- applyToEmojis : default_params [ :frequency_penalty ] [ :apply_to_emojis ]
322
- }
323
- }
324
- end
325
-
326
- def compose_parameters_meta ( params )
327
- default_params = @defaults . merge ( params )
328
-
329
- {
330
- temperature : default_params [ :temperature ] ,
331
- top_p : default_params [ :top_p ] ,
332
- max_gen_len : default_params [ :max_tokens_to_sample ]
333
- }
288
+ params . merge ( anthropic_version : "bedrock-2023-05-31" )
334
289
end
335
290
336
291
def response_from_chunks ( chunks )
0 commit comments