@@ -3,11 +3,9 @@ package me.kpavlov.langchain4j.kotlin.service
33import dev.langchain4j.internal.Exceptions
44import dev.langchain4j.model.input.structured.StructuredPrompt
55import dev.langchain4j.model.input.structured.StructuredPromptProcessor
6- import dev.langchain4j.service.IllegalConfigurationException
6+ import dev.langchain4j.service.InternalReflectionVariableResolver
77import dev.langchain4j.service.MemoryId
8- import dev.langchain4j.service.UserMessage
98import dev.langchain4j.service.UserName
10- import dev.langchain4j.service.V
119import me.kpavlov.langchain4j.kotlin.ChatMemoryId
1210import java.lang.reflect.Method
1311import java.lang.reflect.Parameter
@@ -23,95 +21,24 @@ import java.util.Optional
2321 *
2422 * @see https://github.com/langchain4j/langchain4j/pull/2951
2523 */
26- @Suppress(" detekt:all" )
2724internal object ReflectionVariableResolver {
2825 public fun findTemplateVariables (
2926 template : String ,
3027 method : Method ,
3128 args : Array <Any ?>? ,
32- ): MutableMap <String ?, Any ?> {
33- if (args == null ) {
34- return mutableMapOf<String ?, Any ?>()
35- }
36- val parameters = method.getParameters()
37-
38- val variables: MutableMap <String ?, Any ?> = HashMap <String ?, Any ?>()
39- for (i in args.indices) {
40- val variableName = getVariableName(parameters[i])
41- val variableValue = args[i]
42- variables.put(variableName, variableValue)
43- }
44-
45- if (template.contains(" {{it}}" ) && ! variables.containsKey(" it" )) {
46- val itValue = getValueOfVariableIt(parameters, args)
47- variables.put(" it" , itValue)
48- }
49-
50- return variables
51- }
52-
53- private fun getVariableName (parameter : Parameter ): String? {
54- val annotation = parameter.getAnnotation<V ?>(V ::class .java)
55- if (annotation != null ) {
56- return annotation.value
57- } else {
58- return parameter.getName()
59- }
60- }
61-
62- private fun getValueOfVariableIt (
63- parameters : Array <Parameter >,
64- args : Array <Any ?>? ,
65- ): String? {
66- if (args != null ) {
67- if (args.size == 1 ) {
68- val parameter = parameters[0 ]
69- if (! parameter.isAnnotationPresent(MemoryId ::class .java) &&
70- ! parameter.isAnnotationPresent(
71- UserMessage ::class .java,
72- ) &&
73- ! parameter.isAnnotationPresent(
74- UserName ::class .java,
75- ) &&
76- (
77- ! parameter.isAnnotationPresent(V ::class .java) ||
78- isAnnotatedWithIt(
79- parameter,
80- )
81- )
82- ) {
83- return asString(args[0 ])
84- }
85- }
86-
87- for (i in args.indices) {
88- if (isAnnotatedWithIt(parameters[i])) {
89- return asString(args[i])
90- }
91- }
92- }
29+ ): MutableMap <String ?, Any ?> =
30+ InternalReflectionVariableResolver .findTemplateVariables(template, method, args)
9331
94- throw IllegalConfigurationException .illegalConfiguration(
95- " Error: cannot find the value of the prompt template variable \" {{it}}\" ." ,
96- )
97- }
98-
99- private fun isAnnotatedWithIt (parameter : Parameter ): Boolean {
100- val annotation = parameter.getAnnotation<V ?>(V ::class .java)
101- return annotation != null && " it" == annotation.value
102- }
103-
104- public fun asString (arg : Any? ): String? {
32+ public fun asString (arg : Any? ): String =
10533 if (arg == null ) {
106- return " null"
34+ " null"
10735 } else if (arg is Array <* >? ) {
108- return arrayAsString(arg)
36+ arrayAsString(arg)
10937 } else if (arg.javaClass.isAnnotationPresent(StructuredPrompt ::class .java)) {
110- return StructuredPromptProcessor .toPrompt(arg).text()
38+ StructuredPromptProcessor .toPrompt(arg).text()
11139 } else {
112- return arg.toString()
40+ arg.toString()
11341 }
114- }
11542
11643 private fun arrayAsString (arg : Array <* >? ): String =
11744 if (arg == null ) {
@@ -132,50 +59,65 @@ internal object ReflectionVariableResolver {
13259 fun findUserMessageTemplateFromTheOnlyArgument (
13360 parameters : Array <Parameter >? ,
13461 args : Array <Any ?>,
135- ): Optional <String > {
136- if (parameters != null &&
62+ ): Optional <String > =
63+ if (
64+ parameters != null &&
13765 parameters.size == 1 &&
13866 parameters[0 ].getAnnotations().size == 0
13967 ) {
140- return Optional .ofNullable<String >(asString(args[0 ]))
68+ Optional .ofNullable<String >(asString(args[0 ]))
69+ } else {
70+ Optional .empty()
14171 }
142- return Optional .empty()
143- }
72+
14473
14574 fun findUserName (
14675 parameters : Array <Parameter >,
14776 args : Array <Any ?>,
14877 ): Optional <String > {
149- for (i in parameters.indices) {
78+ var result = Optional .empty<String >()
79+ for (i in args.indices) {
15080 if (parameters[i].isAnnotationPresent(UserName ::class .java)) {
151- return Optional .of<String >(args[i].toString())
81+ result = Optional .of(args[i].toString())
82+ break
15283 }
15384 }
154- return Optional .empty< String >()
85+ return result
15586 }
15687
88+ @Suppress(" ReturnCount" )
15789 fun findMemoryId (
15890 method : Method ,
15991 args : Array <Any ?>? ,
16092 ): Optional <ChatMemoryId > {
16193 if (args == null ) {
162- return Optional .empty< ChatMemoryId > ()
94+ return Optional .empty()
16395 }
96+
97+ val memoryIdParam = findMemoryIdParameter(method, args)
98+ if (memoryIdParam != null ) {
99+ val (parameter, memoryId) = memoryIdParam
100+ if (memoryId is ChatMemoryId ) {
101+ return Optional .of(memoryId)
102+ } else {
103+ throw Exceptions .illegalArgument(
104+ " The value of parameter '%s' annotated with @MemoryId in method '%s' must not be null" ,
105+ parameter.getName(),
106+ method.getName(),
107+ )
108+ }
109+ }
110+
111+ return Optional .empty()
112+ }
113+
114+ private fun findMemoryIdParameter (method : Method , args : Array <Any ?>): Pair <Parameter , Any ?>? {
164115 for (i in args.indices) {
165116 val parameter = method.parameters[i]
166117 if (parameter.isAnnotationPresent(MemoryId ::class .java)) {
167- val memoryId = args[i]
168- if (memoryId is ChatMemoryId ) {
169- return Optional .of(memoryId)
170- } else {
171- throw Exceptions .illegalArgument(
172- " The value of parameter '%s' annotated with @MemoryId in method '%s' must not be null" ,
173- parameter.getName(),
174- method.getName(),
175- )
176- }
118+ return Pair (parameter, args[i])
177119 }
178120 }
179- return Optional .empty()
121+ return null
180122 }
181123}
0 commit comments