Spring AI 1.0.0 M6 适配:增加知识库、工具调用(function calling)、工作流、豆包/混元/硅基流动等模型的接入

This commit is contained in:
YunaiV
2025-03-14 23:34:20 +08:00
parent c2de5d9c8c
commit e2e4b000e6
189 changed files with 4825 additions and 7145 deletions

View File

@@ -1,24 +1,33 @@
package cn.iocoder.yudao.framework.ai.config;
import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.spring.SpringUtil;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowChatModel;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientProperties;
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusVectorStoreProperties;
import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreProperties;
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.Lazy;
/**
* 芋道 AI 自动配置
@@ -26,9 +35,12 @@ import org.springframework.context.annotation.Lazy;
* @author fansili
*/
@AutoConfiguration
@EnableConfigurationProperties(YudaoAiProperties.class)
@EnableConfigurationProperties({ YudaoAiProperties.class,
QdrantVectorStoreProperties.class, // 解析 Qdrant 配置
RedisVectorStoreProperties.class, // 解析 Redis 配置
MilvusVectorStoreProperties.class, MilvusServiceClientProperties.class // 解析 Milvus 配置
})
@Slf4j
@Import(TongYiAutoConfiguration.class)
public class YudaoAiAutoConfiguration {
@Bean
@@ -36,33 +48,148 @@ public class YudaoAiAutoConfiguration {
return new AiModelFactoryImpl();
}
// ========== 各种 AI Client 创建 ==========
@Bean
@ConditionalOnProperty(value = "yudao.ai.deepseek.enable", havingValue = "true")
public DeepSeekChatModel deepSeekChatModel(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.DeepSeekProperties properties = yudaoAiProperties.getDeepSeek();
DeepSeekChatOptions options = DeepSeekChatOptions.builder()
.model(properties.getModel())
.temperature(properties.getTemperature())
.maxTokens(properties.getMaxTokens())
.topP(properties.getTopP())
YudaoAiProperties.DeepSeekProperties properties = yudaoAiProperties.getDeepseek();
return buildDeepSeekChatModel(properties);
}
public DeepSeekChatModel buildDeepSeekChatModel(YudaoAiProperties.DeepSeekProperties properties) {
if (StrUtil.isEmpty(properties.getModel())) {
properties.setModel(DeepSeekChatModel.MODEL_DEFAULT);
}
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(DeepSeekChatModel.BASE_URL)
.apiKey(properties.getApiKey())
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model(properties.getModel())
.temperature(properties.getTemperature())
.maxTokens(properties.getMaxTokens())
.topP(properties.getTopP())
.build())
.toolCallingManager(getToolCallingManager())
.build();
return new DeepSeekChatModel(properties.getApiKey(), options);
return new DeepSeekChatModel(openAiChatModel);
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.doubao.enable", havingValue = "true")
public DouBaoChatModel douBaoChatClient(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.DouBaoProperties properties = yudaoAiProperties.getDoubao();
return buildDouBaoChatClient(properties);
}
public DouBaoChatModel buildDouBaoChatClient(YudaoAiProperties.DouBaoProperties properties) {
if (StrUtil.isEmpty(properties.getModel())) {
properties.setModel(DouBaoChatModel.MODEL_DEFAULT);
}
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(DouBaoChatModel.BASE_URL)
.apiKey(properties.getApiKey())
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model(properties.getModel())
.temperature(properties.getTemperature())
.maxTokens(properties.getMaxTokens())
.topP(properties.getTopP())
.build())
.toolCallingManager(getToolCallingManager())
.build();
return new DouBaoChatModel(openAiChatModel);
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.siliconflow.enable", havingValue = "true")
public SiliconFlowChatModel siliconFlowChatClient(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.SiliconFlowProperties properties = yudaoAiProperties.getSiliconflow();
return buildSiliconFlowChatClient(properties);
}
public SiliconFlowChatModel buildSiliconFlowChatClient(YudaoAiProperties.SiliconFlowProperties properties) {
if (StrUtil.isEmpty(properties.getModel())) {
properties.setModel(SiliconFlowChatModel.MODEL_DEFAULT);
}
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(SiliconFlowChatModel.BASE_URL)
.apiKey(properties.getApiKey())
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model(properties.getModel())
.temperature(properties.getTemperature())
.maxTokens(properties.getMaxTokens())
.topP(properties.getTopP())
.build())
.toolCallingManager(getToolCallingManager())
.build();
return new SiliconFlowChatModel(openAiChatModel);
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.hunyuan.enable", havingValue = "true")
public HunYuanChatModel hunYuanChatClient(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.HunYuanProperties properties = yudaoAiProperties.getHunyuan();
return buildHunYuanChatClient(properties);
}
public HunYuanChatModel buildHunYuanChatClient(YudaoAiProperties.HunYuanProperties properties) {
if (StrUtil.isEmpty(properties.getModel())) {
properties.setModel(HunYuanChatModel.MODEL_DEFAULT);
}
// 特殊:由于混元大模型不提供 deepseek而是通过知识引擎所以需要区分下 URL
if (StrUtil.isEmpty(properties.getBaseUrl())) {
properties.setBaseUrl(
StrUtil.startWithIgnoreCase(properties.getModel(), "deepseek") ? HunYuanChatModel.DEEP_SEEK_BASE_URL
: HunYuanChatModel.BASE_URL);
}
// 创建 OpenAiChatModel、HunYuanChatModel 对象
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(properties.getBaseUrl())
.apiKey(properties.getApiKey())
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model(properties.getModel())
.temperature(properties.getTemperature())
.maxTokens(properties.getMaxTokens())
.topP(properties.getTopP())
.build())
.toolCallingManager(getToolCallingManager())
.build();
return new HunYuanChatModel(openAiChatModel);
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.xinghuo.enable", havingValue = "true")
public XingHuoChatModel xingHuoChatClient(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.XingHuoProperties properties = yudaoAiProperties.getXinghuo();
XingHuoChatOptions options = XingHuoChatOptions.builder()
.model(properties.getModel())
.temperature(properties.getTemperature())
.maxTokens(properties.getMaxTokens())
.topK(properties.getTopK())
return buildXingHuoChatClient(properties);
}
public XingHuoChatModel buildXingHuoChatClient(YudaoAiProperties.XingHuoProperties properties) {
if (StrUtil.isEmpty(properties.getModel())) {
properties.setModel(XingHuoChatModel.MODEL_DEFAULT);
}
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(XingHuoChatModel.BASE_URL)
.apiKey(properties.getAppKey() + ":" + properties.getSecretKey())
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model(properties.getModel())
.temperature(properties.getTemperature())
.maxTokens(properties.getMaxTokens())
.topP(properties.getTopP())
.build())
.toolCallingManager(getToolCallingManager())
.build();
return new XingHuoChatModel(properties.getAppKey(), properties.getSecretKey(), options);
return new XingHuoChatModel(openAiChatModel);
}
@Bean
@@ -78,44 +205,20 @@ public class YudaoAiAutoConfiguration {
return new SunoApi(yudaoAiProperties.getSuno().getBaseUrl());
}
// ========== rag 相关 ==========
// TODO @xin 免费版本
// @Bean
// @Lazy // TODO 芋艿:临时注释,避免无法启动」
// public TransformersEmbeddingModel transformersEmbeddingClient() {
// return new TransformersEmbeddingModel(MetadataMode.EMBED);
// }
/**
* TODO @xin 默认版本先不弄,目前都先取对应的 EmbeddingModel
*/
// @Bean
// @Lazy // TODO 芋艿:临时注释,避免无法启动
// public RedisVectorStore vectorStore(TransformersEmbeddingModel embeddingModel, RedisVectorStoreProperties properties,
// RedisProperties redisProperties) {
// var config = RedisVectorStore.RedisVectorStoreConfig.builder()
// .withIndexName(properties.getIndex())
// .withPrefix(properties.getPrefix())
// .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
// .build();
//
// RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
// properties.isInitializeSchema());
// redisVectorStore.afterPropertiesSet();
// return redisVectorStore;
// }
@Bean
@Lazy // TODO 芋艿:临时注释,避免无法启动
public TokenTextSplitter tokenTextSplitter() {
//TODO @xin 配置提取
return new TokenTextSplitter(500, 100, 5, 10000, true);
}
// ========== RAG 相关 ==========
@Bean
@Lazy // TODO 芋艿:临时注释,避免无法启动
public TokenCountEstimator tokenCountEstimator() {
return new JTokkitTokenCountEstimator();
}
@Bean
public BatchingStrategy batchingStrategy() {
return new TokenCountBatchingStrategy();
}
private static ToolCallingManager getToolCallingManager() {
return SpringUtil.getBean(ToolCallingManager.class);
}
}

View File

@@ -16,11 +16,31 @@ public class YudaoAiProperties {
/**
* DeepSeek
*/
private DeepSeekProperties deepSeek;
@SuppressWarnings("SpellCheckingInspection")
private DeepSeekProperties deepseek;
/**
* 字节豆包
*/
@SuppressWarnings("SpellCheckingInspection")
private DouBaoProperties doubao;
/**
* 腾讯混元
*/
@SuppressWarnings("SpellCheckingInspection")
private HunYuanProperties hunyuan;
/**
* 硅基流动
*/
@SuppressWarnings("SpellCheckingInspection")
private SiliconFlowProperties siliconflow;
/**
* 讯飞星火
*/
@SuppressWarnings("SpellCheckingInspection")
private XingHuoProperties xinghuo;
/**
@@ -31,8 +51,62 @@ public class YudaoAiProperties {
/**
* Suno 音乐
*/
@SuppressWarnings("SpellCheckingInspection")
private SunoProperties suno;
@Data
public static class DeepSeekProperties {
private String enable;
private String apiKey;
private String model;
private Double temperature;
private Integer maxTokens;
private Double topP;
}
@Data
public static class DouBaoProperties {
private String enable;
private String apiKey;
private String model;
private Double temperature;
private Integer maxTokens;
private Double topP;
}
@Data
public static class HunYuanProperties {
private String enable;
private String baseUrl;
private String apiKey;
private String model;
private Double temperature;
private Integer maxTokens;
private Double topP;
}
@Data
public static class SiliconFlowProperties {
private String enable;
private String apiKey;
private String model;
private Double temperature;
private Integer maxTokens;
private Double topP;
}
@Data
public static class XingHuoProperties {
@@ -42,22 +116,9 @@ public class YudaoAiProperties {
private String secretKey;
private String model;
private Float temperature;
private Double temperature;
private Integer maxTokens;
private Integer topK;
}
@Data
public static class DeepSeekProperties {
private String enable;
private String apiKey;
private String model;
private Float temperature;
private Integer maxTokens;
private Float topP;
private Double topP;
}

View File

@@ -0,0 +1,41 @@
package cn.iocoder.yudao.framework.ai.core.enums;
import cn.iocoder.yudao.framework.common.core.ArrayValuable;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.util.Arrays;
/**
* AI 模型类型的枚举
*
* @author 芋道源码
*/
@Getter
@RequiredArgsConstructor
public enum AiModelTypeEnum implements ArrayValuable<Integer> {
CHAT(1, "对话"),
IMAGE(2, "图片"),
VOICE(3, "语音"),
VIDEO(4, "视频"),
EMBEDDING(5, "向量"),
RERANK(6, "重排序");
/**
* 类型
*/
private final Integer type;
/**
* 类型名
*/
private final String name;
public static final Integer[] ARRAYS = Arrays.stream(values()).map(AiModelTypeEnum::getType).toArray(Integer[]::new);
@Override
public Integer[] array() {
return ARRAYS;
}
}

View File

@@ -1,8 +1,11 @@
package cn.iocoder.yudao.framework.ai.core.enums;
import cn.iocoder.yudao.framework.common.core.ArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* AI 模型平台
*
@@ -10,7 +13,7 @@ import lombok.Getter;
*/
@Getter
@AllArgsConstructor
public enum AiPlatformEnum {
public enum AiPlatformEnum implements ArrayValuable<String> {
// ========== 国内平台 ==========
@@ -19,6 +22,11 @@ public enum AiPlatformEnum {
DEEP_SEEK("DeepSeek", "DeepSeek"), // DeepSeek
ZHI_PU("ZhiPu", "智谱"), // 智谱 AI
XING_HUO("XingHuo", "星火"), // 讯飞
DOU_BAO("DouBao", "豆包"), // 字节
HUN_YUAN("HunYuan", "混元"), // 腾讯
SILICON_FLOW("SiliconFlow", "硅基流动"), // 硅基流动
MINI_MAX("MiniMax", "MiniMax"), // 稀宇科技
MOONSHOT("Moonshot", "月之暗灭"), // KIMI
// ========== 国外平台 ==========
@@ -41,6 +49,8 @@ public enum AiPlatformEnum {
*/
private final String name;
public static final String[] ARRAYS = Arrays.stream(values()).map(AiPlatformEnum::getPlatform).toArray(String[]::new);
public static AiPlatformEnum validatePlatform(String platform) {
for (AiPlatformEnum platformEnum : AiPlatformEnum.values()) {
if (platformEnum.getPlatform().equals(platform)) {
@@ -50,4 +60,9 @@ public enum AiPlatformEnum {
throw new IllegalArgumentException("非法平台: " + platform);
}
@Override
public String[] array() {
return ARRAYS;
}
}

View File

@@ -8,6 +8,8 @@ import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.vectorstore.VectorStore;
import java.util.Map;
/**
* AI Model 模型工厂的接口类
*
@@ -89,21 +91,23 @@ public interface AiModelFactory {
* @param platform 平台
* @param apiKey API KEY
* @param url API URL
* @param model 模型
* @return ChatModel 对象
*/
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url);
EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url, String model);
/**
* 基于指定配置,获得 VectorStore 对象
* <p>
*
* 如果不存在,则进行创建
*
* @param embeddingModel 嵌入模
* @param platform 平台
* @param apiKey API KEY
* @param url API URL
* @param type 向量存储类
* @param embeddingModel 向量模型
* @param metadataFields 元数据字段
* @return VectorStore 对象
*/
VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url);
VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type,
EmbeddingModel embeddingModel,
Map<String, Class<?>> metadataFields);
}

View File

@@ -1,74 +1,117 @@
package cn.iocoder.yudao.framework.ai.core.factory;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.lang.Singleton;
import cn.hutool.core.lang.func.Func0;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.RuntimeUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.spring.SpringUtil;
import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowChatModel;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.embeddings.TextEmbedding;
import com.azure.ai.openai.OpenAIClient;
import com.alibaba.cloud.ai.autoconfigure.dashscope.DashScopeAutoConfiguration;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
import com.alibaba.cloud.ai.dashscope.api.DashScopeImageApi;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingModel;
import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingOptions;
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel;
import com.azure.ai.openai.OpenAIClientBuilder;
import io.micrometer.observation.ObservationRegistry;
import io.milvus.client.MilvusServiceClient;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;
import lombok.SneakyThrows;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiEmbeddingProperties;
import org.springframework.ai.autoconfigure.minimax.MiniMaxAutoConfiguration;
import org.springframework.ai.autoconfigure.moonshot.MoonshotAutoConfiguration;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientConnectionDetails;
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientProperties;
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusVectorStoreAutoConfiguration;
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusVectorStoreProperties;
import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreAutoConfiguration;
import org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreProperties;
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreAutoConfiguration;
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.minimax.MiniMaxChatModel;
import org.springframework.ai.minimax.MiniMaxChatOptions;
import org.springframework.ai.minimax.MiniMaxEmbeddingModel;
import org.springframework.ai.minimax.MiniMaxEmbeddingOptions;
import org.springframework.ai.minimax.api.MiniMaxApi;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.moonshot.MoonshotChatModel;
import org.springframework.ai.moonshot.MoonshotChatOptions;
import org.springframework.ai.moonshot.api.MoonshotApi;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.OllamaEmbeddingModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingOptions;
import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.ApiUtils;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.ai.qianfan.QianFanChatModel;
import org.springframework.ai.qianfan.QianFanEmbeddingModel;
import org.springframework.ai.qianfan.QianFanEmbeddingOptions;
import org.springframework.ai.qianfan.QianFanImageModel;
import org.springframework.ai.qianfan.api.QianFanApi;
import org.springframework.ai.qianfan.api.QianFanImageApi;
import org.springframework.ai.stabilityai.StabilityAiImageModel;
import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
import org.springframework.ai.vectorstore.milvus.MilvusVectorStore;
import org.springframework.ai.vectorstore.observation.DefaultVectorStoreObservationConvention;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.ai.vectorstore.qdrant.QdrantVectorStore;
import org.springframework.ai.vectorstore.redis.RedisVectorStore;
import org.springframework.ai.zhipuai.*;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.data.redis.RedisProperties;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.search.Schema;
import java.io.File;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Timer;
import java.util.TimerTask;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static org.springframework.ai.retry.RetryUtils.DEFAULT_RETRY_TEMPLATE;
/**
* AI Model 模型工厂的实现类
@@ -81,7 +124,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
public ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(ChatModel.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
//noinspection EnhancedSwitchMigration
// noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return buildTongYiChatModel(apiKey);
@@ -89,8 +132,18 @@ public class AiModelFactoryImpl implements AiModelFactory {
return buildYiYanChatModel(apiKey);
case DEEP_SEEK:
return buildDeepSeekChatModel(apiKey);
case DOU_BAO:
return buildDouBaoChatModel(apiKey);
case HUN_YUAN:
return buildHunYuanChatModel(apiKey, url);
case SILICON_FLOW:
return buildSiliconFlowChatModel(apiKey);
case ZHI_PU:
return buildZhiPuChatModel(apiKey, url);
case MINI_MAX:
return buildMiniMaxChatModel(apiKey, url);
case MOONSHOT:
return buildMoonshotChatModel(apiKey, url);
case XING_HUO:
return buildXingHuoChatModel(apiKey);
case OPENAI:
@@ -107,16 +160,26 @@ public class AiModelFactoryImpl implements AiModelFactory {
@Override
public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
// noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return SpringUtil.getBean(TongYiChatModel.class);
return SpringUtil.getBean(DashScopeChatModel.class);
case YI_YAN:
return SpringUtil.getBean(QianFanChatModel.class);
case DEEP_SEEK:
return SpringUtil.getBean(DeepSeekChatModel.class);
case DOU_BAO:
return SpringUtil.getBean(DouBaoChatModel.class);
case HUN_YUAN:
return SpringUtil.getBean(HunYuanChatModel.class);
case SILICON_FLOW:
return SpringUtil.getBean(SiliconFlowChatModel.class);
case ZHI_PU:
return SpringUtil.getBean(ZhiPuAiChatModel.class);
case MINI_MAX:
return SpringUtil.getBean(MiniMaxChatModel.class);
case MOONSHOT:
return SpringUtil.getBean(MoonshotChatModel.class);
case XING_HUO:
return SpringUtil.getBean(XingHuoChatModel.class);
case OPENAI:
@@ -132,10 +195,10 @@ public class AiModelFactoryImpl implements AiModelFactory {
@Override
public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
// noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return SpringUtil.getBean(TongYiImagesModel.class);
return SpringUtil.getBean(DashScopeImageModel.class);
case YI_YAN:
return SpringUtil.getBean(QianFanImageModel.class);
case ZHI_PU:
@@ -151,7 +214,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
@Override
public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
//noinspection EnhancedSwitchMigration
// noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return buildTongYiImagesModel(apiKey);
@@ -170,9 +233,11 @@ public class AiModelFactoryImpl implements AiModelFactory {
@Override
public MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url) {
String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey, url);
String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey,
url);
return Singleton.get(cacheKey, (Func0<MidjourneyApi>) () -> {
YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class).getMidjourney();
YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class)
.getMidjourney();
return new MidjourneyApi(url, apiKey, properties.getNotifyUrl());
});
}
@@ -184,13 +249,25 @@ public class AiModelFactoryImpl implements AiModelFactory {
}
@Override
public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url);
@SuppressWarnings("EnhancedSwitchMigration")
public EmbeddingModel getOrCreateEmbeddingModel(AiPlatformEnum platform, String apiKey, String url, String model) {
String cacheKey = buildClientCacheKey(EmbeddingModel.class, platform, apiKey, url, model);
return Singleton.get(cacheKey, (Func0<EmbeddingModel>) () -> {
// TODO @xin 先测试一个
switch (platform) {
case TONG_YI:
return buildTongYiEmbeddingModel(apiKey);
return buildTongYiEmbeddingModel(apiKey, model);
case YI_YAN:
return buildYiYanEmbeddingModel(apiKey, model);
case ZHI_PU:
return buildZhiPuEmbeddingModel(apiKey, url, model);
case MINI_MAX:
return buildMiniMaxEmbeddingModel(apiKey, url, model);
case OPENAI:
return buildOpenAiEmbeddingModel(apiKey, url, model);
case AZURE_OPENAI:
return buildAzureOpenAiEmbeddingModel(apiKey, url, model);
case OLLAMA:
return buildOllamaEmbeddingModel(url, model);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
@@ -198,21 +275,24 @@ public class AiModelFactoryImpl implements AiModelFactory {
}
@Override
public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
public VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type,
EmbeddingModel embeddingModel,
Map<String, Class<?>> metadataFields) {
String cacheKey = buildClientCacheKey(VectorStore.class, embeddingModel, type);
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
var config = RedisVectorStore.RedisVectorStoreConfig.builder()
.withIndexName(cacheKey)
.withPrefix(prefix)
.withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
.build();
RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
true);
redisVectorStore.afterPropertiesSet();
return redisVectorStore;
if (type == SimpleVectorStore.class) {
return buildSimpleVectorStore(embeddingModel);
}
if (type == QdrantVectorStore.class) {
return buildQdrantVectorStore(embeddingModel);
}
if (type == RedisVectorStore.class) {
return buildRedisVectorStore(embeddingModel, metadataFields);
}
if (type == MilvusVectorStore.class) {
return buildMilvusVectorStore(embeddingModel);
}
throw new IllegalArgumentException(StrUtil.format("未知类型({})", type));
});
}
@@ -226,29 +306,25 @@ public class AiModelFactoryImpl implements AiModelFactory {
// ========== 各种创建 spring-ai 客户端的方法 ==========
/**
* 可参考 {@link TongYiAutoConfiguration#tongYiChatClient(Generation, TongYiChatProperties, TongYiConnectionProperties)}
* 可参考 {@link DashScopeAutoConfiguration} 的 dashscopeChatModel 方法
*/
private static TongYiChatModel buildTongYiChatModel(String key) {
com.alibaba.dashscope.aigc.generation.Generation generation = SpringUtil.getBean(Generation.class);
TongYiChatProperties chatOptions = SpringUtil.getBean(TongYiChatProperties.class);
// TODO @芋艿:貌似 apiKey 是全局唯一的???得测试下
// TODO @芋艿:貌似阿里云不是增量返回的
// 该 issue 进行跟进中 https://github.com/alibaba/spring-cloud-alibaba/issues/3790
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
connectionProperties.setApiKey(key);
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
}
private static TongYiImagesModel buildTongYiImagesModel(String key) {
ImageSynthesis imageSynthesis = SpringUtil.getBean(ImageSynthesis.class);
TongYiImagesProperties imagesOptions = SpringUtil.getBean(TongYiImagesProperties.class);
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
connectionProperties.setApiKey(key);
return new TongYiAutoConfiguration().tongYiImagesClient(imageSynthesis, imagesOptions, connectionProperties);
private static DashScopeChatModel buildTongYiChatModel(String key) {
DashScopeApi dashScopeApi = new DashScopeApi(key);
DashScopeChatOptions options = DashScopeChatOptions.builder().withModel(DashScopeApi.DEFAULT_CHAT_MODEL)
.withTemperature(0.7).build();
return new DashScopeChatModel(dashScopeApi, options, getFunctionCallbackResolver(), DEFAULT_RETRY_TEMPLATE);
}
/**
* 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
* 可参考 {@link DashScopeAutoConfiguration} 的 dashScopeImageModel 方法
*/
private static DashScopeImageModel buildTongYiImagesModel(String key) {
DashScopeImageApi dashScopeImageApi = new DashScopeImageApi(key);
return new DashScopeImageModel(dashScopeImageApi);
}
/**
* 可参考 {@link QianFanAutoConfiguration} 的 qianFanChatModel 方法
*/
private static QianFanChatModel buildYiYanChatModel(String key) {
List<String> keys = StrUtil.split(key, '|');
@@ -260,7 +336,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
}
/**
* 可参考 {@link QianFanAutoConfiguration#qianFanImageModel(QianFanConnectionProperties, QianFanImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
* 可参考 {@link QianFanAutoConfiguration} 的 qianFanImageModel 方法
*/
private QianFanImageModel buildQianFanImageModel(String key) {
List<String> keys = StrUtil.split(key, '|');
@@ -275,47 +351,98 @@ public class AiModelFactoryImpl implements AiModelFactory {
* 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)}
*/
private static DeepSeekChatModel buildDeepSeekChatModel(String apiKey) {
return new DeepSeekChatModel(apiKey);
YudaoAiProperties.DeepSeekProperties properties = new YudaoAiProperties.DeepSeekProperties()
.setApiKey(apiKey);
return new YudaoAiAutoConfiguration().buildDeepSeekChatModel(properties);
}
/**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
* 可参考 {@link YudaoAiAutoConfiguration#douBaoChatClient(YudaoAiProperties)}
*/
private ChatModel buildDouBaoChatModel(String apiKey) {
YudaoAiProperties.DouBaoProperties properties = new YudaoAiProperties.DouBaoProperties()
.setApiKey(apiKey);
return new YudaoAiAutoConfiguration().buildDouBaoChatClient(properties);
}
/**
* 可参考 {@link YudaoAiAutoConfiguration#hunYuanChatClient(YudaoAiProperties)}
*/
private ChatModel buildHunYuanChatModel(String apiKey, String url) {
YudaoAiProperties.HunYuanProperties properties = new YudaoAiProperties.HunYuanProperties()
.setBaseUrl(url).setApiKey(apiKey);
return new YudaoAiAutoConfiguration().buildHunYuanChatClient(properties);
}
/**
* 可参考 {@link YudaoAiAutoConfiguration#siliconFlowChatClient(YudaoAiProperties)}
*/
private ChatModel buildSiliconFlowChatModel(String apiKey) {
YudaoAiProperties.SiliconFlowProperties properties = new YudaoAiProperties.SiliconFlowProperties()
.setApiKey(apiKey);
return new YudaoAiAutoConfiguration().buildSiliconFlowChatClient(properties);
}
/**
* 可参考 {@link ZhiPuAiAutoConfiguration} 的 zhiPuAiChatModel 方法
*/
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi(url, apiKey);
return new ZhiPuAiChatModel(zhiPuAiApi);
ZhiPuAiApi zhiPuAiApi = StrUtil.isEmpty(url) ? new ZhiPuAiApi(apiKey)
: new ZhiPuAiApi(url, apiKey);
ZhiPuAiChatOptions options = ZhiPuAiChatOptions.builder().model(ZhiPuAiApi.DEFAULT_CHAT_MODEL).temperature(0.7).build();
return new ZhiPuAiChatModel(zhiPuAiApi, options, getFunctionCallbackResolver(), DEFAULT_RETRY_TEMPLATE);
}
/**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
* 可参考 {@link ZhiPuAiAutoConfiguration} 的 zhiPuAiImageModel 方法
*/
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
ZhiPuAiImageApi zhiPuAiApi = new ZhiPuAiImageApi(url, apiKey, RestClient.builder());
ZhiPuAiImageApi zhiPuAiApi = StrUtil.isEmpty(url) ? new ZhiPuAiImageApi(apiKey)
: new ZhiPuAiImageApi(url, apiKey, RestClient.builder());
return new ZhiPuAiImageModel(zhiPuAiApi);
}
/**
* 可参考 {@link MiniMaxAutoConfiguration} 的 miniMaxChatModel 方法
*/
private MiniMaxChatModel buildMiniMaxChatModel(String apiKey, String url) {
MiniMaxApi miniMaxApi = StrUtil.isEmpty(url) ? new MiniMaxApi(apiKey)
: new MiniMaxApi(url, apiKey);
MiniMaxChatOptions options = MiniMaxChatOptions.builder().model(MiniMaxApi.DEFAULT_CHAT_MODEL).temperature(0.7).build();
return new MiniMaxChatModel(miniMaxApi, options, getFunctionCallbackResolver(), DEFAULT_RETRY_TEMPLATE);
}
/**
* 可参考 {@link MoonshotAutoConfiguration} 的 moonshotChatModel 方法
*/
private MoonshotChatModel buildMoonshotChatModel(String apiKey, String url) {
MoonshotApi moonshotApi = StrUtil.isEmpty(url)? new MoonshotApi(apiKey)
: new MoonshotApi(url, apiKey);
MoonshotChatOptions options = MoonshotChatOptions.builder().model(MoonshotApi.DEFAULT_CHAT_MODEL).build();
return new MoonshotChatModel(moonshotApi, options, getFunctionCallbackResolver(), DEFAULT_RETRY_TEMPLATE);
}
/**
* 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)}
*/
private static XingHuoChatModel buildXingHuoChatModel(String key) {
List<String> keys = StrUtil.split(key, '|');
Assert.equals(keys.size(), 3, "XingHuoChatClient 的密钥需要 (appid|appKey|secretKey) 格式");
String appKey = keys.get(1);
String secretKey = keys.get(2);
return new XingHuoChatModel(appKey, secretKey);
Assert.equals(keys.size(), 2, "XingHuoChatClient 的密钥需要 (appKey|secretKey) 格式");
YudaoAiProperties.XingHuoProperties properties = new YudaoAiProperties.XingHuoProperties()
.setAppKey(keys.get(0)).setSecretKey(keys.get(1));
return new YudaoAiAutoConfiguration().buildXingHuoChatClient(properties);
}
/**
* 可参考 {@link OpenAiAutoConfiguration}
* 可参考 {@link OpenAiAutoConfiguration} 的 openAiChatModel 方法
*/
private static OpenAiChatModel buildOpenAiChatModel(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiApi openAiApi = new OpenAiApi(url, openAiToken);
return new OpenAiChatModel(openAiApi);
url = StrUtil.blankToDefault(url, OpenAiApiConstants.DEFAULT_BASE_URL);
OpenAiApi openAiApi = OpenAiApi.builder().baseUrl(url).apiKey(openAiToken).build();
return OpenAiChatModel.builder().openAiApi(openAiApi).toolCallingManager(getToolCallingManager()).build();
}
// TODO @芋艿:手头暂时没密钥,使用建议再测试下
/**
* 可参考 {@link AzureOpenAiAutoConfiguration}
*/
@@ -325,27 +452,28 @@ public class AiModelFactoryImpl implements AiModelFactory {
AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties();
connectionProperties.setApiKey(apiKey);
connectionProperties.setEndpoint(url);
OpenAIClient openAIClient = azureOpenAiAutoConfiguration.openAIClient(connectionProperties);
OpenAIClientBuilder openAIClient = azureOpenAiAutoConfiguration.openAIClientBuilder(connectionProperties, null);
// 获取 AzureOpenAiChatProperties 对象
AzureOpenAiChatProperties chatProperties = SpringUtil.getBean(AzureOpenAiChatProperties.class);
return azureOpenAiAutoConfiguration.azureOpenAiChatModel(openAIClient, chatProperties, null, null);
return azureOpenAiAutoConfiguration.azureOpenAiChatModel(openAIClient, chatProperties,
getToolCallingManager(), null, null);
}
/**
* 可参考 {@link OpenAiAutoConfiguration}
* 可参考 {@link OpenAiAutoConfiguration} 的 openAiImageModel 方法
*/
private OpenAiImageModel buildOpenAiImageModel(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
url = StrUtil.blankToDefault(url, OpenAiApiConstants.DEFAULT_BASE_URL);
OpenAiImageApi openAiApi = OpenAiImageApi.builder().baseUrl(url).apiKey(openAiToken).build();
return new OpenAiImageModel(openAiApi);
}
/**
* 可参考 {@link OllamaAutoConfiguration}
* 可参考 {@link OllamaAutoConfiguration} 的 ollamaApi 方法
*/
private static OllamaChatModel buildOllamaChatModel(String url) {
OllamaApi ollamaApi = new OllamaApi(url);
return new OllamaChatModel(ollamaApi);
return OllamaChatModel.builder().ollamaApi(ollamaApi).toolCallingManager(getToolCallingManager()).build();
}
private StabilityAiImageModel buildStabilityAiImageModel(String apiKey, String url) {
@@ -357,12 +485,234 @@ public class AiModelFactoryImpl implements AiModelFactory {
// ========== 各种创建 EmbeddingModel 的方法 ==========
/**
* 可参考 {@link TongYiAutoConfiguration#tongYiTextEmbeddingClient(TextEmbedding, TongYiConnectionProperties)}
* 可参考 {@link DashScopeAutoConfiguration} 的 dashscopeEmbeddingModel 方法
*/
private EmbeddingModel buildTongYiEmbeddingModel(String apiKey) {
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
private DashScopeEmbeddingModel buildTongYiEmbeddingModel(String apiKey, String model) {
DashScopeApi dashScopeApi = new DashScopeApi(apiKey);
DashScopeEmbeddingOptions dashScopeEmbeddingOptions = DashScopeEmbeddingOptions.builder().withModel(model).build();
return new DashScopeEmbeddingModel(dashScopeApi, MetadataMode.EMBED, dashScopeEmbeddingOptions);
}
/**
* 可参考 {@link ZhiPuAiAutoConfiguration} 的 zhiPuAiEmbeddingModel 方法
*/
private ZhiPuAiEmbeddingModel buildZhiPuEmbeddingModel(String apiKey, String url, String model) {
ZhiPuAiApi zhiPuAiApi = StrUtil.isEmpty(url) ? new ZhiPuAiApi(apiKey)
: new ZhiPuAiApi(url, apiKey);
ZhiPuAiEmbeddingOptions zhiPuAiEmbeddingOptions = ZhiPuAiEmbeddingOptions.builder().model(model).build();
return new ZhiPuAiEmbeddingModel(zhiPuAiApi, MetadataMode.EMBED, zhiPuAiEmbeddingOptions);
}
/**
* 可参考 {@link MiniMaxAutoConfiguration} 的 miniMaxEmbeddingModel 方法
*/
private EmbeddingModel buildMiniMaxEmbeddingModel(String apiKey, String url, String model) {
MiniMaxApi miniMaxApi = StrUtil.isEmpty(url)? new MiniMaxApi(apiKey)
: new MiniMaxApi(url, apiKey);
MiniMaxEmbeddingOptions miniMaxEmbeddingOptions = MiniMaxEmbeddingOptions.builder().model(model).build();
return new MiniMaxEmbeddingModel(miniMaxApi, MetadataMode.EMBED, miniMaxEmbeddingOptions);
}
/**
* 可参考 {@link QianFanAutoConfiguration} 的 qianFanEmbeddingModel 方法
*/
private QianFanEmbeddingModel buildYiYanEmbeddingModel(String key, String model) {
List<String> keys = StrUtil.split(key, '|');
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
String appKey = keys.get(0);
String secretKey = keys.get(1);
QianFanApi qianFanApi = new QianFanApi(appKey, secretKey);
QianFanEmbeddingOptions qianFanEmbeddingOptions = QianFanEmbeddingOptions.builder().model(model).build();
return new QianFanEmbeddingModel(qianFanApi, MetadataMode.EMBED, qianFanEmbeddingOptions);
}
private OllamaEmbeddingModel buildOllamaEmbeddingModel(String url, String model) {
OllamaApi ollamaApi = new OllamaApi(url);
OllamaOptions ollamaOptions = OllamaOptions.builder().model(model).build();
return OllamaEmbeddingModel.builder().ollamaApi(ollamaApi).defaultOptions(ollamaOptions).build();
}
/**
* 可参考 {@link OpenAiAutoConfiguration} 的 openAiEmbeddingModel 方法
*/
private OpenAiEmbeddingModel buildOpenAiEmbeddingModel(String openAiToken, String url, String model) {
url = StrUtil.blankToDefault(url, OpenAiApiConstants.DEFAULT_BASE_URL);
OpenAiApi openAiApi = OpenAiApi.builder().baseUrl(url).apiKey(openAiToken).build();
OpenAiEmbeddingOptions openAiEmbeddingProperties = OpenAiEmbeddingOptions.builder().model(model).build();
return new OpenAiEmbeddingModel(openAiApi, MetadataMode.EMBED, openAiEmbeddingProperties);
}
// TODO @芋艿:手头暂时没密钥,使用建议再测试下
/**
* 可参考 {@link AzureOpenAiAutoConfiguration} 的 azureOpenAiEmbeddingModel 方法
*/
private AzureOpenAiEmbeddingModel buildAzureOpenAiEmbeddingModel(String apiKey, String url, String model) {
AzureOpenAiAutoConfiguration azureOpenAiAutoConfiguration = new AzureOpenAiAutoConfiguration();
// 创建 OpenAIClient 对象
AzureOpenAiConnectionProperties connectionProperties = new AzureOpenAiConnectionProperties();
connectionProperties.setApiKey(apiKey);
return new TongYiAutoConfiguration().tongYiTextEmbeddingClient(SpringUtil.getBean(TextEmbedding.class), connectionProperties);
connectionProperties.setEndpoint(url);
OpenAIClientBuilder openAIClient = azureOpenAiAutoConfiguration.openAIClientBuilder(connectionProperties, null);
// 获取 AzureOpenAiChatProperties 对象
AzureOpenAiEmbeddingProperties embeddingProperties = SpringUtil.getBean(AzureOpenAiEmbeddingProperties.class);
return azureOpenAiAutoConfiguration.azureOpenAiEmbeddingModel(openAIClient, embeddingProperties,
null, null);
}
// ========== 各种创建 VectorStore 的方法 ==========
/**
* 注意:仅适合本地测试使用,生产建议还是使用 Qdrant、Milvus 等
*/
@SneakyThrows
@SuppressWarnings("ResultOfMethodCallIgnored")
private SimpleVectorStore buildSimpleVectorStore(EmbeddingModel embeddingModel) {
SimpleVectorStore vectorStore = SimpleVectorStore.builder(embeddingModel).build();
// 启动加载
File file = new File(StrUtil.format("{}/vector_store/simple_{}.json",
FileUtil.getUserHomePath(), embeddingModel.getClass().getSimpleName()));
if (!file.exists()) {
FileUtil.mkParentDirs(file);
file.createNewFile();
} else if (file.length() > 0) {
vectorStore.load(file);
}
// 定时持久化,每分钟一次
Timer timer = new Timer("SimpleVectorStoreTimer-" + file.getAbsolutePath());
timer.scheduleAtFixedRate(new TimerTask() {
@Override
public void run() {
vectorStore.save(file);
}
}, Duration.ofMinutes(1).toMillis(), Duration.ofMinutes(1).toMillis());
// 关闭时,进行持久化
RuntimeUtil.addShutdownHook(() -> vectorStore.save(file));
return vectorStore;
}
/**
* 参考 {@link QdrantVectorStoreAutoConfiguration} 的 vectorStore 方法
*/
@SneakyThrows
private QdrantVectorStore buildQdrantVectorStore(EmbeddingModel embeddingModel) {
QdrantVectorStoreAutoConfiguration configuration = new QdrantVectorStoreAutoConfiguration();
QdrantVectorStoreProperties properties = SpringUtil.getBean(QdrantVectorStoreProperties.class);
// 参考 QdrantVectorStoreAutoConfiguration 实现,创建 QdrantClient 对象
QdrantGrpcClient.Builder grpcClientBuilder = QdrantGrpcClient.newBuilder(
properties.getHost(), properties.getPort(), properties.isUseTls());
if (StrUtil.isNotEmpty(properties.getApiKey())) {
grpcClientBuilder.withApiKey(properties.getApiKey());
}
QdrantClient qdrantClient = new QdrantClient(grpcClientBuilder.build());
// 创建 QdrantVectorStore 对象
QdrantVectorStore vectorStore = configuration.vectorStore(embeddingModel, properties, qdrantClient,
getObservationRegistry(), getCustomObservationConvention(), getBatchingStrategy());
// 初始化索引
vectorStore.afterPropertiesSet();
return vectorStore;
}
/**
* 参考 {@link RedisVectorStoreAutoConfiguration} 的 vectorStore 方法
*/
private RedisVectorStore buildRedisVectorStore(EmbeddingModel embeddingModel,
Map<String, Class<?>> metadataFields) {
// 创建 JedisPooled 对象
RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
JedisPooled jedisPooled = new JedisPooled(redisProperties.getHost(), redisProperties.getPort());
// 创建 RedisVectorStoreProperties 对象
RedisVectorStoreAutoConfiguration configuration = new RedisVectorStoreAutoConfiguration();
RedisVectorStoreProperties properties = SpringUtil.getBean(RedisVectorStoreProperties.class);
RedisVectorStore redisVectorStore = RedisVectorStore.builder(jedisPooled, embeddingModel)
.indexName(properties.getIndex()).prefix(properties.getPrefix())
.initializeSchema(properties.isInitializeSchema())
.metadataFields(convertList(metadataFields.entrySet(), entry -> {
String fieldName = entry.getKey();
Class<?> fieldType = entry.getValue();
if (Number.class.isAssignableFrom(fieldType)) {
return RedisVectorStore.MetadataField.numeric(fieldName);
}
if (Boolean.class.isAssignableFrom(fieldType)) {
return RedisVectorStore.MetadataField.tag(fieldName);
}
return RedisVectorStore.MetadataField.text(fieldName);
}))
.observationRegistry(getObservationRegistry().getObject())
.customObservationConvention(getCustomObservationConvention().getObject())
.batchingStrategy(getBatchingStrategy())
.build();
// 初始化索引
redisVectorStore.afterPropertiesSet();
return redisVectorStore;
}
/**
* 参考 {@link MilvusVectorStoreAutoConfiguration} 的 vectorStore 方法
*/
@SneakyThrows
private MilvusVectorStore buildMilvusVectorStore(EmbeddingModel embeddingModel) {
MilvusVectorStoreAutoConfiguration configuration = new MilvusVectorStoreAutoConfiguration();
// 获取配置属性
MilvusVectorStoreProperties serverProperties = SpringUtil.getBean(MilvusVectorStoreProperties.class);
MilvusServiceClientProperties clientProperties = SpringUtil.getBean(MilvusServiceClientProperties.class);
// 创建 MilvusServiceClient 对象
MilvusServiceClient milvusClient = configuration.milvusClient(serverProperties, clientProperties,
new MilvusServiceClientConnectionDetails() {
@Override
public String getHost() {
return clientProperties.getHost();
}
@Override
public int getPort() {
return clientProperties.getPort();
}
}
);
// 创建 MilvusVectorStore 对象
MilvusVectorStore vectorStore = configuration.vectorStore(milvusClient, embeddingModel, serverProperties,
getBatchingStrategy(), getObservationRegistry(), getCustomObservationConvention());
// 初始化索引
vectorStore.afterPropertiesSet();
return vectorStore;
}
private static ObjectProvider<ObservationRegistry> getObservationRegistry() {
return new ObjectProvider<>() {
@Override
public ObservationRegistry getObject() throws BeansException {
return SpringUtil.getBean(ObservationRegistry.class);
}
};
}
private static ObjectProvider<VectorStoreObservationConvention> getCustomObservationConvention() {
return new ObjectProvider<>() {
@Override
public VectorStoreObservationConvention getObject() throws BeansException {
return new DefaultVectorStoreObservationConvention();
}
};
}
private static BatchingStrategy getBatchingStrategy() {
return SpringUtil.getBean(BatchingStrategy.class);
}
private static ToolCallingManager getToolCallingManager() {
return SpringUtil.getBean(ToolCallingManager.class);
}
private static FunctionCallbackResolver getFunctionCallbackResolver() {
return SpringUtil.getBean(FunctionCallbackResolver.class);
}
}

View File

@@ -1,166 +1,45 @@
package cn.iocoder.yudao.framework.ai.core.model.deepseek;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.lang.Assert;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.ai.openai.OpenAiChatModel;
import reactor.core.publisher.Flux;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions.MODEL_DEFAULT;
/**
* DeepSeek {@link ChatModel} 实现类
*
* @author fansili
*/
@Slf4j
@RequiredArgsConstructor
public class DeepSeekChatModel implements ChatModel {
private static final String BASE_URL = "https://api.deepseek.com";
public static final String BASE_URL = "https://api.deepseek.com";
private final DeepSeekChatOptions defaultOptions;
private final RetryTemplate retryTemplate;
public static final String MODEL_DEFAULT = "deepseek-chat";
/**
* DeepSeek 兼容 OpenAI 的 HTTP 接口,所以复用它的实现,简化接入成本
*
* 不过要注意DeepSeek 没有完全兼容,所以不能使用 {@link org.springframework.ai.openai.OpenAiChatModel} 调用,但是实现会参考它
* 兼容 OpenAI 接口,进行复用
*/
private final OpenAiApi openAiApi;
public DeepSeekChatModel(String apiKey) {
this(apiKey, DeepSeekChatOptions.builder().model(MODEL_DEFAULT).temperature(0.7F).build());
}
public DeepSeekChatModel(String apiKey, DeepSeekChatOptions options) {
this(apiKey, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
public DeepSeekChatModel(String apiKey, DeepSeekChatOptions options, RetryTemplate retryTemplate) {
Assert.notEmpty(apiKey, "apiKey 不能为空");
Assert.notNull(options, "options 不能为空");
Assert.notNull(retryTemplate, "retryTemplate 不能为空");
this.openAiApi = new OpenAiApi(BASE_URL, apiKey);
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
}
private final OpenAiChatModel openAiChatModel;
@Override
public ChatResponse call(Prompt prompt) {
OpenAiApi.ChatCompletionRequest request = createRequest(prompt, false);
return this.retryTemplate.execute(ctx -> {
// 1.1 发起调用
ResponseEntity<OpenAiApi.ChatCompletion> completionEntity = openAiApi.chatCompletionEntity(request);
// 1.2 校验结果
OpenAiApi.ChatCompletion chatCompletion = completionEntity.getBody();
if (chatCompletion == null) {
log.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(ListUtil.of());
}
List<OpenAiApi.ChatCompletion.Choice> choices = chatCompletion.choices();
if (choices == null) {
log.warn("No choices returned for prompt: {}", prompt);
return new ChatResponse(ListUtil.of());
}
// 2. 转换 ChatResponse 返回
List<Generation> generations = choices.stream().map(choice -> {
Generation generation = new Generation(choice.message().content(), toMap(chatCompletion.id(), choice));
if (choice.finishReason() != null) {
generation.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
}
return generation;
}).toList();
return new ChatResponse(generations,
OpenAiChatResponseMetadata.from(completionEntity.getBody()));
});
}
private Map<String, Object> toMap(String id, OpenAiApi.ChatCompletion.Choice choice) {
Map<String, Object> map = new HashMap<>();
OpenAiApi.ChatCompletionMessage message = choice.message();
if (message.role() != null) {
map.put("role", message.role().name());
}
if (choice.finishReason() != null) {
map.put("finishReason", choice.finishReason().name());
}
map.put("id", id);
return map;
return openAiChatModel.call(prompt);
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
OpenAiApi.ChatCompletionRequest request = createRequest(prompt, true);
return this.retryTemplate.execute(ctx -> {
// 1. 发起调用
Flux<OpenAiApi.ChatCompletionChunk> response = this.openAiApi.chatCompletionStream(request);
return response.map(chatCompletion -> {
String id = chatCompletion.id();
// 2. 转换 ChatResponse 返回
List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
String finish = (choice.finishReason() != null ? choice.finishReason().name() : "");
String role = (choice.delta().role() != null ? choice.delta().role().name() : "");
if (choice.finishReason() == OpenAiApi.ChatCompletionFinishReason.STOP) {
// 兜底处理 DeepSeek 返回 STOP 时role 为空的情况
role = OpenAiApi.ChatCompletionMessage.Role.ASSISTANT.name();
}
Generation generation = new Generation(choice.delta().content(),
Map.of("id", id, "role", role, "finishReason", finish));
if (choice.finishReason() != null) {
generation = generation.withGenerationMetadata(
ChatGenerationMetadata.from(choice.finishReason().name(), null));
}
return generation;
}).toList();
return new ChatResponse(generations);
});
});
}
OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
// 1. 构建 ChatCompletionMessage 对象
List<OpenAiApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(m ->
new OpenAiApi.ChatCompletionMessage(m.getContent(), OpenAiApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))).toList();
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, stream);
// 2.1 补充 prompt 内置的 options
if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
ChatOptions.class, OpenAiChatOptions.class);
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, OpenAiApi.ChatCompletionRequest.class);
} else {
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
+ prompt.getOptions().getClass().getSimpleName());
}
}
// 2.2 补充默认 options
if (this.defaultOptions != null) {
request = ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.ChatCompletionRequest.class);
}
return request;
return openAiChatModel.stream(prompt);
}
@Override
public ChatOptions getDefaultOptions() {
return DeepSeekChatOptions.fromOptions(defaultOptions);
return openAiChatModel.getDefaultOptions();
}
}

View File

@@ -1,55 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.deepseek;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.ai.chat.prompt.ChatOptions;
/**
* DeepSeek {@link ChatOptions} 实现类
*
* 参考文档:<a href="https://platform.deepseek.com/api-docs/zh-cn/">快速开始</a>
*
* @author fansili
*/
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class DeepSeekChatOptions implements ChatOptions {
public static final String MODEL_DEFAULT = "deepseek-chat";
/**
* 模型
*/
private String model;
/**
* 温度
*/
private Float temperature;
/**
* 最大 Token
*/
private Integer maxTokens;
/**
* topP
*/
private Float topP;
@Override
public Integer getTopK() {
return null;
}
public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) {
return DeepSeekChatOptions.builder()
.model(fromOptions.getModel())
.temperature(fromOptions.getTemperature())
.maxTokens(fromOptions.getMaxTokens())
.topP(fromOptions.getTopP())
.build();
}
}

View File

@@ -0,0 +1,45 @@
package cn.iocoder.yudao.framework.ai.core.model.doubao;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import reactor.core.publisher.Flux;
/**
* 字节豆包 {@link ChatModel} 实现类
*
* @author fansili
*/
@Slf4j
@RequiredArgsConstructor
public class DouBaoChatModel implements ChatModel {
public static final String BASE_URL = "https://ark.cn-beijing.volces.com/api";
public static final String MODEL_DEFAULT = "doubao-1-5-lite-32k-250115";
/**
* 兼容 OpenAI 接口,进行复用
*/
private final OpenAiChatModel openAiChatModel;
@Override
public ChatResponse call(Prompt prompt) {
return openAiChatModel.call(prompt);
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return openAiChatModel.stream(prompt);
}
@Override
public ChatOptions getDefaultOptions() {
return openAiChatModel.getDefaultOptions();
}
}

View File

@@ -0,0 +1,52 @@
package cn.iocoder.yudao.framework.ai.core.model.hunyuan;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import reactor.core.publisher.Flux;
/**
* 腾云混元 {@link ChatModel} 实现类
*
* 1. 混元大模型:基于 <a href="https://cloud.tencent.com/document/product/1729/111007">知识引擎原子能力</a> 实现
* 2. 知识引擎原子能力:基于 <a href="https://cloud.tencent.com/document/product/1772/115969">知识引擎原子能力</a> 实现
*
* @author fansili
*/
@Slf4j
@RequiredArgsConstructor
public class HunYuanChatModel implements ChatModel {
public static final String BASE_URL = "https://api.hunyuan.cloud.tencent.com";
public static final String MODEL_DEFAULT = "hunyuan-turbo";
public static final String DEEP_SEEK_BASE_URL = "https://api.lkeap.cloud.tencent.com";
public static final String DEEP_SEEK_MODEL_DEFAULT = "deepseek-v3";
/**
* 兼容 OpenAI 接口,进行复用
*/
private final OpenAiChatModel openAiChatModel;
@Override
public ChatResponse call(Prompt prompt) {
return openAiChatModel.call(prompt);
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return openAiChatModel.stream(prompt);
}
@Override
public ChatOptions getDefaultOptions() {
return openAiChatModel.getDefaultOptions();
}
}

View File

@@ -8,9 +8,9 @@ import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.openai.api.ApiUtils;
import org.springframework.http.HttpRequest;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;
@@ -50,7 +50,10 @@ public class MidjourneyApi {
public MidjourneyApi(String baseUrl, String apiKey, String notifyUrl) {
this.webClient = WebClient.builder()
.baseUrl(baseUrl)
.defaultHeaders(ApiUtils.getJsonContentHeaders(apiKey))
.defaultHeaders(httpHeaders -> {
httpHeaders.setContentType(MediaType.APPLICATION_JSON);
httpHeaders.setBearerAuth(apiKey);
})
.build();
this.notifyUrl = notifyUrl;
}

View File

@@ -0,0 +1,47 @@
package cn.iocoder.yudao.framework.ai.core.model.siliconflow;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import reactor.core.publisher.Flux;
/**
* 硅基流动 {@link ChatModel} 实现类
*
* 1. API 文档:<a href="https://docs.siliconflow.cn/cn/api-reference/chat-completions/chat-completions">API 文档</a>
*
* @author fansili
*/
@Slf4j
@RequiredArgsConstructor
public class SiliconFlowChatModel implements ChatModel {
public static final String BASE_URL = "https://api.siliconflow.cn";
public static final String MODEL_DEFAULT = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B";
/**
* 兼容 OpenAI 接口,进行复用
*/
private final OpenAiChatModel openAiChatModel;
@Override
public ChatResponse call(Prompt prompt) {
return openAiChatModel.call(prompt);
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return openAiChatModel.stream(prompt);
}
@Override
public ChatOptions getDefaultOptions() {
return openAiChatModel.getDefaultOptions();
}
}

View File

@@ -1,163 +1,45 @@
package cn.iocoder.yudao.framework.ai.core.model.xinghuo;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.lang.Assert;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.ai.openai.OpenAiChatModel;
import reactor.core.publisher.Flux;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions.MODEL_DEFAULT;
/**
* 讯飞星火 {@link ChatModel} 实现类
*
* @author fansili
*/
@Slf4j
@RequiredArgsConstructor
public class XingHuoChatModel implements ChatModel {
private static final String BASE_URL = "https://spark-api-open.xf-yun.com";
public static final String BASE_URL = "https://spark-api-open.xf-yun.com";
private final XingHuoChatOptions defaultOptions;
private final RetryTemplate retryTemplate;
public static final String MODEL_DEFAULT = "generalv3.5";
/**
* 星火兼容 OpenAI 的 HTTP 接口,所以复用它的实现,简化接入成本
*
* 不过要注意,星火没有完全兼容,所以不能使用 {@link org.springframework.ai.openai.OpenAiChatModel} 调用,但是实现会参考它
* 兼容 OpenAI 接口,进行复用
*/
private final OpenAiApi openAiApi;
public XingHuoChatModel(String apiKey, String secretKey) {
this(apiKey, secretKey,
XingHuoChatOptions.builder().model(MODEL_DEFAULT).temperature(0.7F).build());
}
public XingHuoChatModel(String apiKey, String secretKey, XingHuoChatOptions options) {
this(apiKey, secretKey, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
public XingHuoChatModel(String apiKey, String secretKey, XingHuoChatOptions options, RetryTemplate retryTemplate) {
Assert.notEmpty(apiKey, "apiKey 不能为空");
Assert.notEmpty(secretKey, "secretKey 不能为空");
Assert.notNull(options, "options 不能为空");
Assert.notNull(retryTemplate, "retryTemplate 不能为空");
this.openAiApi = new OpenAiApi(BASE_URL, apiKey + ":" + secretKey);
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
}
private final OpenAiChatModel openAiChatModel;
@Override
public ChatResponse call(Prompt prompt) {
OpenAiApi.ChatCompletionRequest request = createRequest(prompt, false);
return this.retryTemplate.execute(ctx -> {
// 1.1 发起调用
ResponseEntity<OpenAiApi.ChatCompletion> completionEntity = openAiApi.chatCompletionEntity(request);
// 1.2 校验结果
OpenAiApi.ChatCompletion chatCompletion = completionEntity.getBody();
if (chatCompletion == null) {
log.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(ListUtil.of());
}
List<OpenAiApi.ChatCompletion.Choice> choices = chatCompletion.choices();
if (choices == null) {
log.warn("No choices returned for prompt: {}", prompt);
return new ChatResponse(ListUtil.of());
}
// 2. 转换 ChatResponse 返回
List<Generation> generations = choices.stream().map(choice -> {
Generation generation = new Generation(choice.message().content(), toMap(chatCompletion.id(), choice));
if (choice.finishReason() != null) {
generation.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
}
return generation;
}).toList();
return new ChatResponse(generations,
OpenAiChatResponseMetadata.from(completionEntity.getBody()));
});
}
private Map<String, Object> toMap(String id, OpenAiApi.ChatCompletion.Choice choice) {
Map<String, Object> map = new HashMap<>();
OpenAiApi.ChatCompletionMessage message = choice.message();
if (message.role() != null) {
map.put("role", message.role().name());
}
if (choice.finishReason() != null) {
map.put("finishReason", choice.finishReason().name());
}
map.put("id", id);
return map;
return openAiChatModel.call(prompt);
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
OpenAiApi.ChatCompletionRequest request = createRequest(prompt, true);
return this.retryTemplate.execute(ctx -> {
// 1. 发起调用
Flux<OpenAiApi.ChatCompletionChunk> response = this.openAiApi.chatCompletionStream(request);
return response.map(chatCompletion -> {
String id = chatCompletion.id();
// 2. 转换 ChatResponse 返回
List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
String finish = (choice.finishReason() != null ? choice.finishReason().name() : "");
Generation generation = new Generation(choice.delta().content(),
Map.of("id", id, "role", choice.delta().role().name(), "finishReason", finish));
if (choice.finishReason() != null) {
generation = generation.withGenerationMetadata(
ChatGenerationMetadata.from(choice.finishReason().name(), null));
}
return generation;
}).toList();
return new ChatResponse(generations);
});
});
}
OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
// 1. 构建 ChatCompletionMessage 对象
List<OpenAiApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(m ->
new OpenAiApi.ChatCompletionMessage(m.getContent(), OpenAiApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))).toList();
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, stream);
// 2.1 补充 prompt 内置的 options
if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
ChatOptions.class, OpenAiChatOptions.class);
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, OpenAiApi.ChatCompletionRequest.class);
} else {
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
+ prompt.getOptions().getClass().getSimpleName());
}
}
// 2.2 补充默认 options
if (this.defaultOptions != null) {
request = ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.ChatCompletionRequest.class);
}
return request;
return openAiChatModel.stream(prompt);
}
@Override
public ChatOptions getDefaultOptions() {
return XingHuoChatOptions.fromOptions(defaultOptions);
return openAiChatModel.getDefaultOptions();
}
}

View File

@@ -1,55 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.xinghuo;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.ai.chat.prompt.ChatOptions;
/**
* 讯飞星火 {@link ChatOptions} 实现类
*
* 参考文档:<a href="https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html">HTTP 调用</a>
*
* @author fansili
*/
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class XingHuoChatOptions implements ChatOptions {
public static final String MODEL_DEFAULT = "generalv3.5";
/**
* 模型
*/
private String model;
/**
* 温度
*/
private Float temperature;
/**
* 最大 Token
*/
private Integer maxTokens;
/**
* K 个候选
*/
private Integer topK;
@Override
public Float getTopP() {
return null;
}
public static XingHuoChatOptions fromOptions(XingHuoChatOptions fromOptions) {
return XingHuoChatOptions.builder()
.model(fromOptions.getModel())
.temperature(fromOptions.getTemperature())
.maxTokens(fromOptions.getMaxTokens())
.topK(fromOptions.getTopK())
.build();
}
}

View File

@@ -2,17 +2,19 @@ package cn.iocoder.yudao.framework.ai.core.util;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.minimax.MiniMaxChatOptions;
import org.springframework.ai.moonshot.MoonshotChatOptions;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.qianfan.QianFanChatOptions;
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
import java.util.Set;
/**
* Spring AI 工具类
*
@@ -21,26 +23,42 @@ import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
public class AiUtils {
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
Float temperatureF = temperature != null ? temperature.floatValue() : null;
//noinspection EnhancedSwitchMigration
return buildChatOptions(platform, model, temperature, maxTokens, null);
}
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
Set<String> toolNames) {
// noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens)
.withFunctions(toolNames).build();
case YI_YAN:
return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case DEEP_SEEK:
return DeepSeekChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
case ZHI_PU:
return ZhiPuAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case XING_HUO:
return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.functions(toolNames).build();
case MINI_MAX:
return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.functions(toolNames).build();
case MOONSHOT:
return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.functions(toolNames).build();
case OPENAI:
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case DEEP_SEEK: // 复用 OpenAI 客户端
case DOU_BAO: // 复用 OpenAI 客户端
case HUN_YUAN: // 复用 OpenAI 客户端
case XING_HUO: // 复用 OpenAI 客户端
case SILICON_FLOW: // 复用 OpenAI 客户端
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.toolNames(toolNames).build();
case AZURE_OPENAI:
// TODO 芋艿:貌似没 model 字段???!
return AzureOpenAiChatOptions.builder().withDeploymentName(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens)
.toolNames(toolNames).build();
case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
.toolNames(toolNames).build();
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
@@ -56,8 +74,8 @@ public class AiUtils {
if (MessageType.SYSTEM.getValue().equals(type)) {
return new SystemMessage(content);
}
if (MessageType.FUNCTION.getValue().equals(type)) {
return new FunctionMessage(content);
if (MessageType.TOOL.getValue().equals(type)) {
throw new UnsupportedOperationException("暂不支持 tool 消息:" + content);
}
throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
}

View File

@@ -4,7 +4,10 @@
* models 包路径:
* 1. xinghuo 包:【讯飞】星火,自己实现
* 2. deepseek 包【深度求索】DeepSeek自己实现
* 3. midjourney 包Midjourney API对接 https://github.com/novicezk/midjourney-proxy 实现
* 4. suno 包Suno API对接 https://github.com/gcui-art/suno-api 实现
* 3. doubao 包【字节豆包】DouBao自己实现
* 4. hunyuan 包【腾讯混元】HunYuan自己实现
* 5. siliconflow 包【硅基硅流】SiliconFlow自己实现
* 6. midjourney 包Midjourney API对接 https://github.com/novicezk/midjourney-proxy 实现
* 7. suno 包Suno API对接 https://github.com/gcui-art/suno-api 实现
*/
package cn.iocoder.yudao.framework.ai;

View File

@@ -1,253 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi;
import com.alibaba.cloud.ai.tongyi.audio.speech.TongYiAudioSpeechModel;
import com.alibaba.cloud.ai.tongyi.audio.speech.TongYiAudioSpeechProperties;
import com.alibaba.cloud.ai.tongyi.audio.transcription.TongYiAudioTranscriptionModel;
import com.alibaba.cloud.ai.tongyi.audio.transcription.TongYiAudioTranscriptionProperties;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
import com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants;
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiException;
import com.alibaba.cloud.ai.tongyi.embedding.TongYiTextEmbeddingModel;
import com.alibaba.cloud.ai.tongyi.embedding.TongYiTextEmbeddingProperties;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.audio.asr.transcription.Transcription;
import com.alibaba.dashscope.audio.tts.SpeechSynthesizer;
import com.alibaba.dashscope.common.MessageManager;
import com.alibaba.dashscope.embeddings.TextEmbedding;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.utils.ApiKey;
import com.alibaba.dashscope.utils.Constants;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Scope;
import java.util.Objects;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
@AutoConfiguration
@ConditionalOnClass({
MessageManager.class,
TongYiChatModel.class,
TongYiImagesModel.class,
TongYiAudioSpeechModel.class,
TongYiTextEmbeddingModel.class,
TongYiAudioTranscriptionModel.class
})
@EnableConfigurationProperties({
TongYiChatProperties.class,
TongYiImagesProperties.class,
TongYiAudioSpeechProperties.class,
TongYiConnectionProperties.class,
TongYiTextEmbeddingProperties.class,
TongYiAudioTranscriptionProperties.class
})
public class TongYiAutoConfiguration {
@Bean
@Scope("prototype")
@ConditionalOnMissingBean
public Generation generation() {
return new Generation();
}
@Bean
@Scope("prototype")
@ConditionalOnMissingBean
public MessageManager msgManager() {
return new MessageManager(10);
}
@Bean
@Scope("prototype")
@ConditionalOnMissingBean
public ImageSynthesis imageSynthesis() {
return new ImageSynthesis();
}
@Bean
@Scope("prototype")
@ConditionalOnMissingBean
public SpeechSynthesizer speechSynthesizer() {
return new SpeechSynthesizer();
}
@Bean
@ConditionalOnMissingBean
public Transcription transcription() {
return new Transcription();
}
@Bean
@ConditionalOnMissingBean
public TextEmbedding textEmbedding() {
return new TextEmbedding();
}
@Bean
@ConditionalOnMissingBean
public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) {
FunctionCallbackContext manager = new FunctionCallbackContext();
manager.setApplicationContext(context);
return manager;
}
@Bean
@ConditionalOnProperty(
prefix = TongYiChatProperties.CONFIG_PREFIX,
name = "enabled",
havingValue = "true",
matchIfMissing = true
)
public TongYiChatModel tongYiChatClient(Generation generation,
TongYiChatProperties chatOptions,
TongYiConnectionProperties connectionProperties
) {
settingApiKey(connectionProperties);
return new TongYiChatModel(generation, chatOptions.getOptions());
}
@Bean
@ConditionalOnProperty(
prefix = TongYiImagesProperties.CONFIG_PREFIX,
name = "enabled",
havingValue = "true",
matchIfMissing = true
)
public TongYiImagesModel tongYiImagesClient(
ImageSynthesis imageSynthesis,
TongYiImagesProperties imagesOptions,
TongYiConnectionProperties connectionProperties
) {
settingApiKey(connectionProperties);
return new TongYiImagesModel(imageSynthesis, imagesOptions.getOptions());
}
@Bean
@ConditionalOnProperty(
prefix = TongYiAudioSpeechProperties.CONFIG_PREFIX,
name = "enabled",
havingValue = "true",
matchIfMissing = true
)
public TongYiAudioSpeechModel tongYiAudioSpeechClient(
SpeechSynthesizer speechSynthesizer,
TongYiAudioSpeechProperties speechProperties,
TongYiConnectionProperties connectionProperties
) {
settingApiKey(connectionProperties);
return new TongYiAudioSpeechModel(speechSynthesizer, speechProperties.getOptions());
}
@Bean
@ConditionalOnProperty(
prefix = TongYiAudioTranscriptionProperties.CONFIG_PREFIX,
name = "enabled",
havingValue = "true",
matchIfMissing = true
)
public TongYiAudioTranscriptionModel tongYiAudioTranscriptionClient(
Transcription transcription,
TongYiAudioTranscriptionProperties transcriptionProperties,
TongYiConnectionProperties connectionProperties) {
settingApiKey(connectionProperties);
return new TongYiAudioTranscriptionModel(
transcriptionProperties.getOptions(),
transcription
);
}
@Bean
@ConditionalOnProperty(
prefix = TongYiTextEmbeddingProperties.CONFIG_PREFIX,
name = "enabled",
havingValue = "true",
matchIfMissing = true
)
public TongYiTextEmbeddingModel tongYiTextEmbeddingClient(
TextEmbedding textEmbedding,
TongYiConnectionProperties connectionProperties
) {
settingApiKey(connectionProperties);
return new TongYiTextEmbeddingModel(textEmbedding);
}
/**
* Setting the API key.
* @param connectionProperties {@link TongYiConnectionProperties}
*/
private void settingApiKey(TongYiConnectionProperties connectionProperties) {
String apiKey;
try {
// It is recommended to set the key by defining the api-key in an environment variable.
var envKey = System.getenv(TongYiConstants.SCA_AI_TONGYI_API_KEY);
if (Objects.nonNull(envKey)) {
Constants.apiKey = envKey;
return;
}
if (Objects.nonNull(connectionProperties.getApiKey())) {
apiKey = connectionProperties.getApiKey();
}
else {
apiKey = ApiKey.getApiKey(null);
}
Constants.apiKey = apiKey;
}
catch (NoApiKeyException e) {
throw new TongYiException(e.getMessage());
}
}
}

View File

@@ -1,52 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi;
import org.springframework.boot.context.properties.ConfigurationProperties;
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
/**
* Spring Cloud Alibaba AI TongYi LLM connection properties.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
@ConfigurationProperties(TongYiConnectionProperties.CONFIG_PREFIX)
public class TongYiConnectionProperties {
/**
* Spring Cloud Alibaba AI connection configuration Prefix.
*/
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "tongyi";
/**
* TongYi LLM API key.
*/
private String apiKey;
public String getApiKey() {
return apiKey;
}
public void setApiKey(String apiKey) {
this.apiKey = apiKey;
}
}

View File

@@ -1,40 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio;
/**
* More models see: https://help.aliyun.com/zh/dashscope/model-list?spm=a2c4g.11186623.0.i5
* Support all models in list.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public final class AudioSpeechModels {
private AudioSpeechModels() {
}
/**
* Male Voice of the Tongue(舌尖男声).
* zh & en.
* Default sample rate: 48 Hz.
*/
public static final String SAMBERT_ZHICHU_V1 = "sambert-zhichu-v1";
}

View File

@@ -1,43 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio;
/**
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
public final class AudioTranscriptionModels {
private AudioTranscriptionModels() {
}
/**
* Paraformer Chinese and English speech recognition model supports audio or video speech recognition with a sampling rate of 16kHz or above.
*/
public static final String Paraformer_V1 = "paraformer-v1";
/**
* Paraformer Chinese speech recognition model, support 8kHz telephone speech recognition.
*/
public static final String Paraformer_8K_V1 = "paraformer-8k-v1";
/**
* The Paraformer multilingual speech recognition model supports audio or video speech recognition with a sample rate of 16kHz or above.
*/
public static final String Paraformer_MTL_V1 = "paraformer-mtl-v1";
}

View File

@@ -1,228 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech;
import com.alibaba.cloud.ai.tongyi.audio.AudioSpeechModels;
import com.alibaba.cloud.ai.tongyi.audio.speech.api.*;
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioSpeechResponseMetadata;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisParam;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisResult;
import com.alibaba.dashscope.audio.tts.SpeechSynthesizer;
import com.alibaba.dashscope.common.ResultCallback;
import io.reactivex.Flowable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import java.nio.ByteBuffer;
/**
* TongYiAudioSpeechClient is a client for TongYi audio speech service for Spring Cloud Alibaba AI.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiAudioSpeechModel implements SpeechModel, SpeechStreamModel {
private final Logger logger = LoggerFactory.getLogger(getClass());
/**
* Default speed rate.
*/
private static final float SPEED_RATE = 1.0f;
/**
* TongYi models api.
*/
private final SpeechSynthesizer speechSynthesizer;
/**
* TongYi models options.
*/
private final TongYiAudioSpeechOptions defaultOptions;
/**
* TongYiAudioSpeechClient constructor.
* @param speechSynthesizer the speech synthesizer
*/
public TongYiAudioSpeechModel(SpeechSynthesizer speechSynthesizer) {
this(speechSynthesizer, null);
}
/**
* TongYiAudioSpeechClient constructor.
* @param speechSynthesizer the speech synthesizer
* @param tongYiAudioOptions the tongYi audio options
*/
public TongYiAudioSpeechModel(SpeechSynthesizer speechSynthesizer, TongYiAudioSpeechOptions tongYiAudioOptions) {
Assert.notNull(speechSynthesizer, "speechSynthesizer must not be null");
Assert.notNull(tongYiAudioOptions, "tongYiAudioOptions must not be null");
this.speechSynthesizer = speechSynthesizer;
this.defaultOptions = tongYiAudioOptions;
}
/**
* Call the TongYi audio speech service.
* @param text the text message to be converted to audio.
* @return the audio byte buffer.
*/
@Override
public ByteBuffer call(String text) {
var speechRequest = new SpeechPrompt(text);
return call(speechRequest).getResult().getOutput();
}
/**
* Call the TongYi audio speech service.
* @param prompt the speech prompt.
* @return the speech response.
*/
@Override
public SpeechResponse call(SpeechPrompt prompt) {
var SCASpeechParam = merge(prompt.getOptions());
var speechSynthesisParams = toSpeechSynthesisParams(SCASpeechParam);
speechSynthesisParams.setText(prompt.getInstructions().getText());
logger.info(speechSynthesisParams.toString());
var res = speechSynthesizer.call(speechSynthesisParams);
return convert(res, null);
}
/**
* Call the TongYi audio speech service.
* @param prompt the speech prompt.
* @param callback the result callback.
* {@link SpeechSynthesizer#call(SpeechSynthesisParam, ResultCallback)}
*/
public void call(SpeechPrompt prompt, ResultCallback<SpeechSynthesisResult> callback) {
var SCASpeechParam = merge(prompt.getOptions());
var speechSynthesisParams = toSpeechSynthesisParams(SCASpeechParam);
speechSynthesisParams.setText(prompt.getInstructions().getText());
speechSynthesizer.call(speechSynthesisParams, callback);
}
/**
* Stream the TongYi audio speech service.
* @param prompt the speech prompt.
* @return the speech response.
* {@link SpeechSynthesizer#streamCall(SpeechSynthesisParam)}
*/
@Override
public Flux<SpeechResponse> stream(SpeechPrompt prompt) {
var SCASpeechParam = merge(prompt.getOptions());
Flowable<SpeechSynthesisResult> resultFlowable = speechSynthesizer
.streamCall(toSpeechSynthesisParams(SCASpeechParam));
return Flux.from(resultFlowable)
.flatMap(
res -> Flux.just(res.getAudioFrame())
.map(audio -> {
var speech = new Speech(audio);
var respMetadata = TongYiAudioSpeechResponseMetadata.from(res);
return new SpeechResponse(speech, respMetadata);
})
).publishOn(Schedulers.parallel());
}
public TongYiAudioSpeechOptions merge(TongYiAudioSpeechOptions target) {
var mergeBuilder = TongYiAudioSpeechOptions.builder();
mergeBuilder.withModel(defaultOptions.getModel() != null ? defaultOptions.getModel() : target.getModel());
mergeBuilder.withPitch(defaultOptions.getPitch() != null ? defaultOptions.getPitch() : target.getPitch());
mergeBuilder.withRate(defaultOptions.getRate() != null ? defaultOptions.getRate() : target.getRate());
mergeBuilder.withFormat(defaultOptions.getFormat() != null ? defaultOptions.getFormat() : target.getFormat());
mergeBuilder.withSampleRate(defaultOptions.getSampleRate() != null ? defaultOptions.getSampleRate() : target.getSampleRate());
mergeBuilder.withTextType(defaultOptions.getTextType() != null ? defaultOptions.getTextType() : target.getTextType());
mergeBuilder.withVolume(defaultOptions.getVolume() != null ? defaultOptions.getVolume() : target.getVolume());
mergeBuilder.withEnablePhonemeTimestamp(defaultOptions.isEnablePhonemeTimestamp() != null ? defaultOptions.isEnablePhonemeTimestamp() : target.isEnablePhonemeTimestamp());
mergeBuilder.withEnableWordTimestamp(defaultOptions.isEnableWordTimestamp() != null ? defaultOptions.isEnableWordTimestamp() : target.isEnableWordTimestamp());
return mergeBuilder.build();
}
public SpeechSynthesisParam toSpeechSynthesisParams(TongYiAudioSpeechOptions source) {
var mergeBuilder = SpeechSynthesisParam.builder();
mergeBuilder.model(source.getModel() != null ? source.getModel() : AudioSpeechModels.SAMBERT_ZHICHU_V1);
mergeBuilder.text(source.getText() != null ? source.getText() : "");
if (source.getFormat() != null) {
mergeBuilder.format(source.getFormat());
}
if (source.getRate() != null) {
mergeBuilder.rate(source.getRate());
}
if (source.getPitch() != null) {
mergeBuilder.pitch(source.getPitch());
}
if (source.getTextType() != null) {
mergeBuilder.textType(source.getTextType());
}
if (source.getSampleRate() != null) {
mergeBuilder.sampleRate(source.getSampleRate());
}
if (source.isEnablePhonemeTimestamp() != null) {
mergeBuilder.enablePhonemeTimestamp(source.isEnablePhonemeTimestamp());
}
if (source.isEnableWordTimestamp() != null) {
mergeBuilder.enableWordTimestamp(source.isEnableWordTimestamp());
}
if (source.getVolume() != null) {
mergeBuilder.volume(source.getVolume());
}
return mergeBuilder.build();
}
/**
* Convert the TongYi audio speech service result to the speech response.
* @param result the audio byte buffer.
* @param synthesisResult the synthesis result.
* @return the speech response.
*/
private SpeechResponse convert(ByteBuffer result, SpeechSynthesisResult synthesisResult) {
if (synthesisResult == null) {
return new SpeechResponse(new Speech(result));
}
var responseMetadata = TongYiAudioSpeechResponseMetadata.from(synthesisResult);
var speech = new Speech(synthesisResult.getAudioFrame());
return new SpeechResponse(speech, responseMetadata);
}
}

View File

@@ -1,261 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech;
import com.alibaba.cloud.ai.tongyi.audio.AudioSpeechModels;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisAudioFormat;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisTextType;
import org.springframework.ai.model.ModelOptions;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiAudioSpeechOptions implements ModelOptions {
/**
* Audio Speech models.
*/
private String model = AudioSpeechModels.SAMBERT_ZHICHU_V1;
/**
* Text content.
*/
private String text;
/**
* Input text type.
*/
private SpeechSynthesisTextType textType = SpeechSynthesisTextType.PLAIN_TEXT;
/**
* synthesis audio format.
*/
private SpeechSynthesisAudioFormat format = SpeechSynthesisAudioFormat.WAV;
/**
* synthesis audio sample rate.
*/
private Integer sampleRate = 16000;
/**
* synthesis audio volume.
*/
private Integer volume = 50;
/**
* synthesis audio speed.
*/
private Float rate = 1.0f;
/**
* synthesis audio pitch.
*/
private Float pitch = 1.0f;
/**
* enable word level timestamp.
*/
private Boolean enableWordTimestamp = false;
/**
* enable phoneme level timestamp.
*/
private Boolean enablePhonemeTimestamp = false;
public static Builder builder() {
return new Builder();
}
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public String getText() {
return text;
}
public void setText(String text) {
this.text = text;
}
public SpeechSynthesisTextType getTextType() {
return textType;
}
public void setTextType(SpeechSynthesisTextType textType) {
this.textType = textType;
}
public SpeechSynthesisAudioFormat getFormat() {
return format;
}
public void setFormat(SpeechSynthesisAudioFormat format) {
this.format = format;
}
public Integer getSampleRate() {
return sampleRate;
}
public void setSampleRate(Integer sampleRate) {
this.sampleRate = sampleRate;
}
public Integer getVolume() {
return volume;
}
public void setVolume(Integer volume) {
this.volume = volume;
}
public Float getRate() {
return rate;
}
public void setRate(Float rate) {
this.rate = rate;
}
public Float getPitch() {
return pitch;
}
public void setPitch(Float pitch) {
this.pitch = pitch;
}
public Boolean isEnableWordTimestamp() {
return enableWordTimestamp;
}
public void setEnableWordTimestamp(Boolean enableWordTimestamp) {
this.enableWordTimestamp = enableWordTimestamp;
}
public Boolean isEnablePhonemeTimestamp() {
return enablePhonemeTimestamp;
}
public void setEnablePhonemeTimestamp(Boolean enablePhonemeTimestamp) {
this.enablePhonemeTimestamp = enablePhonemeTimestamp;
}
/**
* Build a options instances.
*/
public static class Builder {
private final TongYiAudioSpeechOptions options = new TongYiAudioSpeechOptions();
public Builder withModel(String model) {
options.model = model;
return this;
}
public Builder withText(String text) {
options.text = text;
return this;
}
public Builder withTextType(SpeechSynthesisTextType textType) {
options.textType = textType;
return this;
}
public Builder withFormat(SpeechSynthesisAudioFormat format) {
options.format = format;
return this;
}
public Builder withSampleRate(Integer sampleRate) {
options.sampleRate = sampleRate;
return this;
}
public Builder withVolume(Integer volume) {
options.volume = volume;
return this;
}
public Builder withRate(Float rate) {
options.rate = rate;
return this;
}
public Builder withPitch(Float pitch) {
options.pitch = pitch;
return this;
}
public Builder withEnableWordTimestamp(Boolean enableWordTimestamp) {
options.enableWordTimestamp = enableWordTimestamp;
return this;
}
public Builder withEnablePhonemeTimestamp(Boolean enablePhonemeTimestamp) {
options.enablePhonemeTimestamp = enablePhonemeTimestamp;
return this;
}
public TongYiAudioSpeechOptions build() {
return options;
}
}
}

View File

@@ -1,77 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech;
import com.alibaba.cloud.ai.tongyi.audio.AudioSpeechModels;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisAudioFormat;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
/**
* TongYi audio speech configuration properties.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
@ConfigurationProperties(TongYiAudioSpeechProperties.CONFIG_PREFIX)
public class TongYiAudioSpeechProperties {
/**
* Spring Cloud Alibaba AI configuration prefix.
*/
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "audio.speech";
/**
* Default TongYi Chat model.
*/
public static final String DEFAULT_AUDIO_MODEL_NAME = AudioSpeechModels.SAMBERT_ZHICHU_V1;
/**
* Enable TongYiQWEN ai audio client.
*/
private boolean enabled = true;
@NestedConfigurationProperty
private TongYiAudioSpeechOptions options = TongYiAudioSpeechOptions.builder()
.withModel(DEFAULT_AUDIO_MODEL_NAME)
.withFormat(SpeechSynthesisAudioFormat.WAV)
.build();
public TongYiAudioSpeechOptions getOptions() {
return this.options;
}
public void setOptions(TongYiAudioSpeechOptions options) {
this.options = options;
}
public boolean isEnabled() {
return this.enabled;
}
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
}

View File

@@ -1,87 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import org.springframework.ai.model.ModelResult;
import org.springframework.lang.Nullable;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Objects;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class Speech implements ModelResult<ByteBuffer> {
private final ByteBuffer audio;
private SpeechMetadata speechMetadata;
public Speech(ByteBuffer audio) {
this.audio = audio;
}
@Override
public ByteBuffer getOutput() {
return this.audio;
}
@Override
public SpeechMetadata getMetadata() {
return speechMetadata != null ? speechMetadata : SpeechMetadata.NULL;
}
public Speech withSpeechMetadata(@Nullable SpeechMetadata speechMetadata) {
this.speechMetadata = speechMetadata;
return this;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof Speech that)) {
return false;
}
return Arrays.equals(audio.array(), that.audio.array())
&& Objects.equals(speechMetadata, that.speechMetadata);
}
@Override
public int hashCode() {
return Objects.hash(Arrays.hashCode(audio.array()), speechMetadata);
}
@Override
public String toString() {
return "Speech{" + "text=" + audio + ", speechMetadata=" + speechMetadata + '}';
}
}

View File

@@ -1,80 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import java.util.Objects;
/**
* The {@link SpeechMessage} class represents a single text message to
* be converted to speech by the TongYi LLM TTS.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class SpeechMessage {
private String text;
/**
* Constructs a new {@link SpeechMessage} object with the given text.
* @param text the text to be converted to speech
*/
public SpeechMessage(String text) {
this.text = text;
}
/**
* Returns the text of this speech message.
* @return the text of this speech message
*/
public String getText() {
return text;
}
/**
* Sets the text of this speech message.
* @param text the new text for this speech message
*/
public void setText(String text) {
this.text = text;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof SpeechMessage that)) {
return false;
}
return Objects.equals(text, that.text);
}
@Override
public int hashCode() {
return Objects.hash(text);
}
}

View File

@@ -1,43 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import org.springframework.ai.model.ResultMetadata;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public interface SpeechMetadata extends ResultMetadata {
/**
* Null Object.
*/
SpeechMetadata NULL = SpeechMetadata.create();
/**
* Factory method used to construct a new {@link SpeechMetadata}.
* @return a new {@link SpeechMetadata}
*/
static SpeechMetadata create() {
return new SpeechMetadata() {
};
}
}

View File

@@ -1,51 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import org.springframework.ai.model.Model;
import java.nio.ByteBuffer;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.0.0-RC1
*/
@FunctionalInterface
public interface SpeechModel extends Model<SpeechPrompt, SpeechResponse> {
/**
* Generates spoken audio from the provided text message.
* @param message the text message to be converted to audio.
* @return the resulting audio bytes.
*/
default ByteBuffer call(String message) {
SpeechPrompt prompt = new SpeechPrompt(message);
return call(prompt).getResult().getOutput();
}
/**
* Sends a speech request to the TongYi TTS API and returns the resulting speech response.
* @param request the speech prompt containing the input text and other parameters.
* @return the speech response containing the generated audio.
*/
SpeechResponse call(SpeechPrompt request);
}

View File

@@ -1,89 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import com.alibaba.cloud.ai.tongyi.audio.speech.TongYiAudioSpeechOptions;
import org.springframework.ai.model.ModelRequest;
import java.util.Objects;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class SpeechPrompt implements ModelRequest<SpeechMessage> {
private TongYiAudioSpeechOptions speechOptions;
private final SpeechMessage message;
public SpeechPrompt(String instructions) {
this(new SpeechMessage(instructions), TongYiAudioSpeechOptions.builder().build());
}
public SpeechPrompt(String instructions, TongYiAudioSpeechOptions speechOptions) {
this(new SpeechMessage(instructions), speechOptions);
}
public SpeechPrompt(SpeechMessage speechMessage) {
this(speechMessage, TongYiAudioSpeechOptions.builder().build());
}
public SpeechPrompt(SpeechMessage speechMessage, TongYiAudioSpeechOptions speechOptions) {
this.message = speechMessage;
this.speechOptions = speechOptions;
}
@Override
public SpeechMessage getInstructions() {
return this.message;
}
@Override
public TongYiAudioSpeechOptions getOptions() {
return speechOptions;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof SpeechPrompt that)) {
return false;
}
return Objects.equals(speechOptions, that.speechOptions) && Objects.equals(message, that.message);
}
@Override
public int hashCode() {
return Objects.hash(speechOptions, message);
}
}

View File

@@ -1,100 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioSpeechResponseMetadata;
import org.springframework.ai.model.ModelResponse;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class SpeechResponse implements ModelResponse<Speech> {
private final Speech speech;
private final TongYiAudioSpeechResponseMetadata speechResponseMetadata;
/**
* Creates a new instance of SpeechResponse with the given speech result.
* @param speech the speech result to be set in the SpeechResponse
* @see Speech
*/
public SpeechResponse(Speech speech) {
this(speech, TongYiAudioSpeechResponseMetadata.NULL);
}
/**
* Creates a new instance of SpeechResponse with the given speech result and speech
* response metadata.
* @param speech the speech result to be set in the SpeechResponse
* @param speechResponseMetadata the speech response metadata to be set in the
* SpeechResponse
* @see Speech
* @see TongYiAudioSpeechResponseMetadata
*/
public SpeechResponse(Speech speech, TongYiAudioSpeechResponseMetadata speechResponseMetadata) {
this.speech = speech;
this.speechResponseMetadata = speechResponseMetadata;
}
@Override
public Speech getResult() {
return speech;
}
@Override
public List<Speech> getResults() {
return Collections.singletonList(speech);
}
@Override
public TongYiAudioSpeechResponseMetadata getMetadata() {
return speechResponseMetadata;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof SpeechResponse that)) {
return false;
}
return Objects.equals(speech, that.speech)
&& Objects.equals(speechResponseMetadata, that.speechResponseMetadata);
}
@Override
public int hashCode() {
return Objects.hash(speech, speechResponseMetadata);
}
}

View File

@@ -1,54 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import org.springframework.ai.model.StreamingModel;
import reactor.core.publisher.Flux;
import java.nio.ByteBuffer;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
@FunctionalInterface
public interface SpeechStreamModel extends StreamingModel<SpeechPrompt, SpeechResponse> {
/**
* Generates a stream of audio bytes from the provided text message.
*
* @param message the text message to be converted to audio
* @return a Flux of audio bytes representing the generated speech
*/
default Flux<ByteBuffer> stream(String message) {
SpeechPrompt prompt = new SpeechPrompt(message);
return stream(prompt).map(SpeechResponse::getResult).map(Speech::getOutput);
}
/**
* Sends a speech request to the TongYi TTS API and returns a stream of the resulting
* speech responses.
* @param prompt the speech prompt containing the input text and other parameters.
* @return a Flux of speech responses, each containing a portion of the generated audio.
*/
@Override
Flux<SpeechResponse> stream(SpeechPrompt prompt);
}

View File

@@ -1,187 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.transcription;
import cn.hutool.core.collection.ListUtil;
import com.alibaba.cloud.ai.tongyi.audio.AudioTranscriptionModels;
import com.alibaba.cloud.ai.tongyi.audio.transcription.api.AudioTranscriptionPrompt;
import com.alibaba.cloud.ai.tongyi.audio.transcription.api.AudioTranscriptionResponse;
import com.alibaba.cloud.ai.tongyi.audio.transcription.api.AudioTranscriptionResult;
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiException;
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioTranscriptionResponseMetadata;
import com.alibaba.dashscope.audio.asr.transcription.*;
import org.springframework.ai.model.Model;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
/**
* TongYiAudioTranscriptionModel is a client for TongYi audio transcription service for
* Spring Cloud Alibaba AI.
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
public class TongYiAudioTranscriptionModel
implements Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
/**
* TongYi models options.
*/
private final TongYiAudioTranscriptionOptions defaultOptions;
/**
* TongYi models api.
*/
private final Transcription transcription;
public TongYiAudioTranscriptionModel(Transcription transcription) {
this(null, transcription);
}
public TongYiAudioTranscriptionModel(TongYiAudioTranscriptionOptions defaultOptions,
Transcription transcription) {
Assert.notNull(transcription, "transcription must not be null");
Assert.notNull(defaultOptions, "defaultOptions must not be null");
this.defaultOptions = defaultOptions;
this.transcription = transcription;
}
@Override
public AudioTranscriptionResponse call(AudioTranscriptionPrompt prompt) {
TranscriptionParam transcriptionParam;
if (prompt.getOptions() != null) {
var param = merge(prompt.getOptions());
transcriptionParam = toTranscriptionParam(param);
transcriptionParam.setFileUrls(prompt.getOptions().getFileUrls());
}
else {
Resource instructions = prompt.getInstructions();
try {
transcriptionParam = TranscriptionParam.builder()
.model(AudioTranscriptionModels.Paraformer_V1)
.fileUrls(ListUtil.of(String.valueOf(instructions.getURL())))
.build();
}
catch (IOException e) {
throw new TongYiException("Failed to create resource", e);
}
}
List<TranscriptionTaskResult> taskResultList;
try {
// Submit a transcription request
TranscriptionResult result = transcription.asyncCall(transcriptionParam);
// Wait for the transcription to complete
result = transcription.wait(TranscriptionQueryParam
.FromTranscriptionParam(transcriptionParam, result.getTaskId()));
// Get the transcription results
System.out.println(result.getOutput().getAsJsonObject());
taskResultList = result.getResults();
System.out.println(Arrays.toString(taskResultList.toArray()));
return new AudioTranscriptionResponse(
taskResultList.stream().map(taskResult ->
new AudioTranscriptionResult(taskResult.getTranscriptionUrl())
).collect(Collectors.toList()),
TongYiAudioTranscriptionResponseMetadata.from(result)
);
}
catch (Exception e) {
throw new TongYiException("Failed to call audio transcription", e);
}
}
public TongYiAudioTranscriptionOptions merge(TongYiAudioTranscriptionOptions target) {
var mergeBuilder = TongYiAudioTranscriptionOptions.builder();
mergeBuilder
.withModel(defaultOptions.getModel() != null ? defaultOptions.getModel()
: target.getModel());
mergeBuilder.withChannelId(
defaultOptions.getChannelId() != null ? defaultOptions.getChannelId()
: target.getChannelId());
mergeBuilder.withDiarizationEnabled(defaultOptions.getDiarizationEnabled() != null
? defaultOptions.getDiarizationEnabled()
: target.getDiarizationEnabled());
mergeBuilder.withDisfluencyRemovalEnabled(
defaultOptions.getDisfluencyRemovalEnabled() != null
? defaultOptions.getDisfluencyRemovalEnabled()
: target.getDisfluencyRemovalEnabled());
mergeBuilder.withTimestampAlignmentEnabled(
defaultOptions.getTimestampAlignmentEnabled() != null
? defaultOptions.getTimestampAlignmentEnabled()
: target.getTimestampAlignmentEnabled());
mergeBuilder.withSpecialWordFilter(defaultOptions.getSpecialWordFilter() != null
? defaultOptions.getSpecialWordFilter()
: target.getSpecialWordFilter());
mergeBuilder.withAudioEventDetectionEnabled(
defaultOptions.getAudioEventDetectionEnabled() != null
? defaultOptions.getAudioEventDetectionEnabled()
: target.getAudioEventDetectionEnabled());
return mergeBuilder.build();
}
public TranscriptionParam toTranscriptionParam(
TongYiAudioTranscriptionOptions source) {
var mergeBuilder = TranscriptionParam.builder();
mergeBuilder.model(source.getModel() != null ? source.getModel()
: AudioTranscriptionModels.Paraformer_V1);
mergeBuilder.fileUrls(
source.getFileUrls() != null ? source.getFileUrls() : new ArrayList<>());
if (source.getPhraseId() != null) {
mergeBuilder.phraseId(source.getPhraseId());
}
if (source.getChannelId() != null) {
mergeBuilder.channelId(source.getChannelId());
}
if (source.getDiarizationEnabled() != null) {
mergeBuilder.diarizationEnabled(source.getDiarizationEnabled());
}
if (source.getSpeakerCount() != null) {
mergeBuilder.speakerCount(source.getSpeakerCount());
}
if (source.getDisfluencyRemovalEnabled() != null) {
mergeBuilder.disfluencyRemovalEnabled(source.getDisfluencyRemovalEnabled());
}
if (source.getTimestampAlignmentEnabled() != null) {
mergeBuilder.timestampAlignmentEnabled(source.getTimestampAlignmentEnabled());
}
if (source.getSpecialWordFilter() != null) {
mergeBuilder.specialWordFilter(source.getSpecialWordFilter());
}
if (source.getAudioEventDetectionEnabled() != null) {
mergeBuilder
.audioEventDetectionEnabled(source.getAudioEventDetectionEnabled());
}
return mergeBuilder.build();
}
}

View File

@@ -1,203 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.transcription;
import com.alibaba.cloud.ai.tongyi.audio.AudioTranscriptionModels;
import org.springframework.ai.model.ModelOptions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
public class TongYiAudioTranscriptionOptions implements ModelOptions {
private String model = AudioTranscriptionModels.Paraformer_V1;
private List<String> fileUrls = new ArrayList<>();
private String phraseId = null;
private List<Integer> channelId = Collections.singletonList(0);
private Boolean diarizationEnabled = false;
private Integer speakerCount = null;
private Boolean disfluencyRemovalEnabled = false;
private Boolean timestampAlignmentEnabled = false;
private String specialWordFilter = "";
private Boolean audioEventDetectionEnabled = false;
public static Builder builder() {
return new Builder();
}
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public List<String> getFileUrls() {
return fileUrls;
}
public void setFileUrls(List<String> fileUrls) {
this.fileUrls = fileUrls;
}
public String getPhraseId() {
return phraseId;
}
public void setPhraseId(String phraseId) {
this.phraseId = phraseId;
}
public List<Integer> getChannelId() {
return channelId;
}
public void setChannelId(List<Integer> channelId) {
this.channelId = channelId;
}
public Boolean getDiarizationEnabled() {
return diarizationEnabled;
}
public void setDiarizationEnabled(Boolean diarizationEnabled) {
this.diarizationEnabled = diarizationEnabled;
}
public Integer getSpeakerCount() {
return speakerCount;
}
public void setSpeakerCount(Integer speakerCount) {
this.speakerCount = speakerCount;
}
public Boolean getDisfluencyRemovalEnabled() {
return disfluencyRemovalEnabled;
}
public void setDisfluencyRemovalEnabled(Boolean disfluencyRemovalEnabled) {
this.disfluencyRemovalEnabled = disfluencyRemovalEnabled;
}
public Boolean getTimestampAlignmentEnabled() {
return timestampAlignmentEnabled;
}
public void setTimestampAlignmentEnabled(Boolean timestampAlignmentEnabled) {
this.timestampAlignmentEnabled = timestampAlignmentEnabled;
}
public String getSpecialWordFilter() {
return specialWordFilter;
}
public void setSpecialWordFilter(String specialWordFilter) {
this.specialWordFilter = specialWordFilter;
}
public Boolean getAudioEventDetectionEnabled() {
return audioEventDetectionEnabled;
}
public void setAudioEventDetectionEnabled(Boolean audioEventDetectionEnabled) {
this.audioEventDetectionEnabled = audioEventDetectionEnabled;
}
/**
* Builder class for constructing TongYiAudioTranscriptionOptions instances.
*/
public static class Builder {
private final TongYiAudioTranscriptionOptions options = new TongYiAudioTranscriptionOptions();
public Builder withModel(String model) {
options.model = model;
return this;
}
public Builder withFileUrls(List<String> fileUrls) {
options.fileUrls = fileUrls;
return this;
}
public Builder withPhraseId(String phraseId) {
options.phraseId = phraseId;
return this;
}
public Builder withChannelId(List<Integer> channelId) {
options.channelId = channelId;
return this;
}
public Builder withDiarizationEnabled(Boolean diarizationEnabled) {
options.diarizationEnabled = diarizationEnabled;
return this;
}
public Builder withSpeakerCount(Integer speakerCount) {
options.speakerCount = speakerCount;
return this;
}
public Builder withDisfluencyRemovalEnabled(Boolean disfluencyRemovalEnabled) {
options.disfluencyRemovalEnabled = disfluencyRemovalEnabled;
return this;
}
public Builder withTimestampAlignmentEnabled(Boolean timestampAlignmentEnabled) {
options.timestampAlignmentEnabled = timestampAlignmentEnabled;
return this;
}
public Builder withSpecialWordFilter(String specialWordFilter) {
options.specialWordFilter = specialWordFilter;
return this;
}
public Builder withAudioEventDetectionEnabled(
Boolean audioEventDetectionEnabled) {
options.audioEventDetectionEnabled = audioEventDetectionEnabled;
return this;
}
public TongYiAudioTranscriptionOptions build() {
// Perform any necessary validation here before returning the built object
return options;
}
}
}

View File

@@ -1,72 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.transcription;
import com.alibaba.cloud.ai.tongyi.audio.AudioTranscriptionModels;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
/**
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
@ConfigurationProperties(TongYiAudioTranscriptionProperties.CONFIG_PREFIX)
public class TongYiAudioTranscriptionProperties {
/**
* Spring Cloud Alibaba AI configuration prefix.
*/
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "audio.transcription";
/**
* Default TongYi Chat model.
*/
public static final String DEFAULT_AUDIO_MODEL_NAME = AudioTranscriptionModels.Paraformer_V1;
/**
* Enable TongYiQWEN ai audio client.
*/
private boolean enabled = true;
@NestedConfigurationProperty
private TongYiAudioTranscriptionOptions options = TongYiAudioTranscriptionOptions
.builder().withModel(DEFAULT_AUDIO_MODEL_NAME).build();
public TongYiAudioTranscriptionOptions getOptions() {
return this.options;
}
public void setOptions(TongYiAudioTranscriptionOptions options) {
this.options = options;
}
public boolean isEnabled() {
return this.enabled;
}
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
}

View File

@@ -1,56 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.transcription.api;
import com.alibaba.cloud.ai.tongyi.audio.transcription.TongYiAudioTranscriptionOptions;
import org.springframework.ai.model.ModelRequest;
import org.springframework.core.io.Resource;
/**
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
public class AudioTranscriptionPrompt implements ModelRequest<Resource> {
private Resource audioResource;
private TongYiAudioTranscriptionOptions transcriptionOptions;
public AudioTranscriptionPrompt(Resource resource) {
this.audioResource = resource;
}
public AudioTranscriptionPrompt(Resource resource, TongYiAudioTranscriptionOptions transcriptionOptions) {
this.audioResource = resource;
this.transcriptionOptions = transcriptionOptions;
}
@Override
public Resource getInstructions() {
return audioResource;
}
@Override
public TongYiAudioTranscriptionOptions getOptions() {
return transcriptionOptions;
}
}

View File

@@ -1,67 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.transcription.api;
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioTranscriptionResponseMetadata;
import org.springframework.ai.model.ModelResponse;
import org.springframework.ai.model.ResponseMetadata;
import java.util.List;
/**
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
public class AudioTranscriptionResponse implements ModelResponse<AudioTranscriptionResult> {
private List<AudioTranscriptionResult> resultList;
private TongYiAudioTranscriptionResponseMetadata transcriptionResponseMetadata;
public AudioTranscriptionResponse(List<AudioTranscriptionResult> result) {
this(result, TongYiAudioTranscriptionResponseMetadata.NULL);
}
public AudioTranscriptionResponse(List<AudioTranscriptionResult> result,
TongYiAudioTranscriptionResponseMetadata transcriptionResponseMetadata) {
this.resultList = List.copyOf(result);
this.transcriptionResponseMetadata = transcriptionResponseMetadata;
}
@Override
public AudioTranscriptionResult getResult() {
return this.resultList.get(0);
}
@Override
public List<AudioTranscriptionResult> getResults() {
return this.resultList;
}
@Override
public ResponseMetadata getMetadata() {
return this.transcriptionResponseMetadata;
}
}

View File

@@ -1,68 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.transcription.api;
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioTranscriptionMetadata;
import org.springframework.ai.model.ModelResult;
import org.springframework.ai.model.ResultMetadata;
import java.util.Objects;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class AudioTranscriptionResult implements ModelResult<String> {
private String text;
private TongYiAudioTranscriptionMetadata transcriptionMetadata;
public AudioTranscriptionResult(String text) {
this.text = text;
}
@Override
public String getOutput() {
return this.text;
}
@Override
public ResultMetadata getMetadata() {
return transcriptionMetadata != null ? transcriptionMetadata : TongYiAudioTranscriptionMetadata.NULL;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
AudioTranscriptionResult that = (AudioTranscriptionResult) o;
return Objects.equals(text, that.text) && Objects.equals(transcriptionMetadata, that.transcriptionMetadata);
}
@Override
public int hashCode() {
return Objects.hash(text, transcriptionMetadata);
}
}

View File

@@ -1,483 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.chat;
import cn.hutool.core.collection.ListUtil;
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiException;
import com.alibaba.dashscope.aigc.conversation.ConversationParam;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationOutput;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.common.MessageManager;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.tools.FunctionDefinition;
import com.alibaba.dashscope.tools.ToolCallBase;
import com.alibaba.dashscope.tools.ToolCallFunction;
import com.alibaba.dashscope.utils.ApiKeywords;
import com.alibaba.dashscope.utils.JsonUtils;
import io.reactivex.Flowable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
/**
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal Alibaba DashScope}
* backed by {@link Generation}.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
* @see ChatModel
* @see com.alibaba.dashscope.aigc.generation
*/
public class TongYiChatModel extends
AbstractFunctionCallSupport<
com.alibaba.dashscope.common.Message,
ConversationParam,
GenerationResult>
implements ChatModel, StreamingChatModel {
private static final Logger logger = LoggerFactory.getLogger(TongYiChatModel.class);
/**
* DashScope generation client.
*/
private final Generation generation;
/**
* The TongYi models default chat completion api.
*/
private TongYiChatOptions defaultOptions;
/**
* User role message manager.
*/
@Autowired
private MessageManager msgManager;
/**
* Initializes an instance of the TongYiChatClient.
* @param generation DashScope generation client.
*/
public TongYiChatModel(Generation generation) {
this(generation,
TongYiChatOptions.builder()
.withTopP(0.8)
.withEnableSearch(true)
.withResultFormat(ConversationParam.ResultFormat.MESSAGE)
.build(),
null
);
}
/**
* Initializes an instance of the TongYiChatClient.
* @param generation DashScope generation client.
* @param options TongYi model params.
*/
public TongYiChatModel(Generation generation, TongYiChatOptions options) {
this(generation, options, null);
}
/**
* Create a TongYi models client.
* @param generation DashScope model generation client.
* @param options TongYi default chat completion api.
*/
public TongYiChatModel(Generation generation, TongYiChatOptions options,
FunctionCallbackContext functionCallbackContext) {
super(functionCallbackContext);
this.generation = generation;
this.defaultOptions = options;
}
/**
* Get default sca chat options.
*
* @return TongYiChatOptions default object.
*/
public TongYiChatOptions getDefaultOptions() {
return this.defaultOptions;
}
@Override
public ChatResponse call(Prompt prompt) {
ConversationParam params = toTongYiChatParams(prompt);
// TongYi models context loader.
com.alibaba.dashscope.common.Message message = new com.alibaba.dashscope.common.Message();
message.setRole(Role.USER.getValue());
message.setContent(prompt.getContents());
msgManager.add(message);
params.setMessages(msgManager.get());
logger.trace("TongYi ConversationOptions: {}", params);
GenerationResult chatCompletions = this.callWithFunctionSupport(params);
logger.trace("TongYi ConversationOptions: {}", params);
msgManager.add(chatCompletions);
List<org.springframework.ai.chat.model.Generation> generations =
chatCompletions
.getOutput()
.getChoices()
.stream()
.map(choice ->
new org.springframework.ai.chat.model.Generation(
choice
.getMessage()
.getContent()
).withGenerationMetadata(generateChoiceMetadata(choice)
))
.toList();
return new ChatResponse(generations);
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
Flowable<GenerationResult> genRes;
ConversationParam tongYiChatParams = toTongYiChatParams(prompt);
// See https://help.aliyun.com/zh/dashscope/developer-reference/api-details?spm=a2c4g.11186623.0.0.655fc11aRR0jj7#b9ad0a10cfhpe
// tongYiChatParams.setIncrementalOutput(true);
try {
genRes = generation.streamCall(tongYiChatParams);
}
catch (NoApiKeyException | InputRequiredException e) {
logger.warn("TongYi chat client: " + e.getMessage());
throw new TongYiException(e.getMessage());
}
return Flux.from(genRes)
.flatMap(
message -> Flux.just(
message.getOutput()
.getChoices()
.get(0)
.getMessage()
.getContent())
.map(content -> {
var gen = new org.springframework.ai.chat.model.Generation(content)
.withGenerationMetadata(generateChoiceMetadata(
message.getOutput()
.getChoices()
.get(0)
));
return new ChatResponse(ListUtil.of(gen));
})
)
.publishOn(Schedulers.parallel());
}
/**
* Configuration properties to Qwen model params.
* Test access.
*
* @param prompt {@link Prompt}
* @return Qwen models params {@link ConversationParam}
*/
public ConversationParam toTongYiChatParams(Prompt prompt) {
Set<String> functionsForThisRequest = new HashSet<>();
List<com.alibaba.dashscope.common.Message> tongYiMessage = prompt.getInstructions().stream()
.map(this::fromSpringAIMessage)
.toList();
ConversationParam chatParams = ConversationParam.builder()
.messages(tongYiMessage)
// models setting
// {@link HalfDuplexServiceParam#models}
.model(Generation.Models.QWEN_TURBO)
// {@link GenerationOutput}
.resultFormat(ConversationParam.ResultFormat.MESSAGE)
.incrementalOutput(true)
.build();
if (this.defaultOptions != null) {
chatParams = merge(chatParams, this.defaultOptions);
Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, !IS_RUNTIME_CALL);
functionsForThisRequest.addAll(defaultEnabledFunctions);
}
if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
TongYiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
ChatOptions.class, TongYiChatOptions.class);
chatParams = merge(updatedRuntimeOptions, chatParams);
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
IS_RUNTIME_CALL);
functionsForThisRequest.addAll(promptEnabledFunctions);
}
else {
throw new IllegalArgumentException("Prompt options are not of type ConversationParam:"
+ prompt.getOptions().getClass().getSimpleName());
}
}
// Add the enabled functions definitions to the request's tools parameter.
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
List<FunctionDefinition> tools = this.getFunctionTools(functionsForThisRequest);
// todo chatParams.setTools(tools)
}
return chatParams;
}
private ChatGenerationMetadata generateChoiceMetadata(GenerationOutput.Choice choice) {
return ChatGenerationMetadata.from(
String.valueOf(choice.getFinishReason()),
choice.getMessage().getContent()
);
}
private List<FunctionDefinition> getFunctionTools(Set<String> functionNames) {
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
FunctionDefinition functionDefinition = FunctionDefinition.builder()
.name(functionCallback.getName())
.description(functionCallback.getDescription())
.parameters(JsonUtils.parametersToJsonObject(
ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema())
))
.build();
return functionDefinition;
}).toList();
}
private ConversationParam merge(ConversationParam tongYiParams, TongYiChatOptions scaChatParams) {
if (scaChatParams == null) {
return tongYiParams;
}
return ConversationParam.builder()
.messages(tongYiParams.getMessages())
.maxTokens((tongYiParams.getMaxTokens() != null) ? tongYiParams.getMaxTokens() : scaChatParams.getMaxTokens())
// When merge options. Because ConversationParams is must not null. So is setting.
.model(scaChatParams.getModel())
.resultFormat((tongYiParams.getResultFormat() != null) ? tongYiParams.getResultFormat() : scaChatParams.getResultFormat())
.enableSearch((tongYiParams.getEnableSearch() != null) ? tongYiParams.getEnableSearch() : scaChatParams.getEnableSearch())
.topK((tongYiParams.getTopK() != null) ? tongYiParams.getTopK() : scaChatParams.getTopK())
.topP((tongYiParams.getTopP() != null) ? tongYiParams.getTopP() : scaChatParams.getTopP())
.incrementalOutput((tongYiParams.getIncrementalOutput() != null) ? tongYiParams.getIncrementalOutput() : scaChatParams.getIncrementalOutput())
.temperature((tongYiParams.getTemperature() != null) ? tongYiParams.getTemperature() : scaChatParams.getTemperature())
.repetitionPenalty((tongYiParams.getRepetitionPenalty() != null) ? tongYiParams.getRepetitionPenalty() : scaChatParams.getRepetitionPenalty())
.seed((tongYiParams.getSeed() != null) ? tongYiParams.getSeed() : scaChatParams.getSeed())
.build();
}
private ConversationParam merge(TongYiChatOptions scaChatParams, ConversationParam tongYiParams) {
if (scaChatParams == null) {
return tongYiParams;
}
ConversationParam mergedTongYiParams = ConversationParam.builder()
.model(Generation.Models.QWEN_TURBO)
.messages(tongYiParams.getMessages())
.build();
mergedTongYiParams = merge(tongYiParams, scaChatParams);
if (scaChatParams.getMaxTokens() != null) {
mergedTongYiParams.setMaxTokens(scaChatParams.getMaxTokens());
}
if (scaChatParams.getStop() != null) {
mergedTongYiParams.setStopStrings(scaChatParams.getStop());
}
if (scaChatParams.getTemperature() != null) {
mergedTongYiParams.setTemperature(scaChatParams.getTemperature());
}
if (scaChatParams.getTopK() != null) {
mergedTongYiParams.setTopK(scaChatParams.getTopK());
}
if (scaChatParams.getTopK() != null) {
mergedTongYiParams.setTopK(scaChatParams.getTopK());
}
return mergedTongYiParams;
}
private com.alibaba.dashscope.common.Message fromSpringAIMessage(Message message) {
return switch (message.getMessageType()) {
case USER -> com.alibaba.dashscope.common.Message.builder()
.role(Role.USER.getValue())
.content(message.getContent())
.build();
case SYSTEM -> com.alibaba.dashscope.common.Message.builder()
.role(Role.SYSTEM.getValue())
.content(message.getContent())
.build();
case ASSISTANT -> com.alibaba.dashscope.common.Message.builder()
.role(Role.ASSISTANT.getValue())
.content(message.getContent())
.build();
default -> throw new IllegalArgumentException("Unknown message type " + message.getMessageType());
};
}
@Override
protected ConversationParam doCreateToolResponseRequest(
ConversationParam previousRequest,
com.alibaba.dashscope.common.Message responseMessage,
List<com.alibaba.dashscope.common.Message> conversationHistory
) {
for (ToolCallBase toolCall : responseMessage.getToolCalls()) {
if (toolCall instanceof ToolCallFunction toolCallFunction) {
if (toolCallFunction.getFunction() != null) {
var functionName = toolCallFunction.getFunction().getName();
var functionArguments = toolCallFunction.getFunction().getArguments();
if (!this.functionCallbackRegister.containsKey(functionName)) {
throw new IllegalStateException("No function callback found for function name: " + functionName);
}
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
// Add the function response to the conversation.
conversationHistory
.add(com.alibaba.dashscope.common.Message.builder()
.content(functionResponse)
.role(Role.BOT.getValue())
.toolCallId(toolCall.getId())
.build()
);
}
}
}
ConversationParam newRequest = ConversationParam.builder().messages(conversationHistory).build();
// todo: No @JsonProperty fields.
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ConversationParam.class);
return newRequest;
}
@Override
protected List<com.alibaba.dashscope.common.Message> doGetUserMessages(ConversationParam request) {
return request.getMessages();
}
@Override
protected com.alibaba.dashscope.common.Message doGetToolResponseMessage(GenerationResult response) {
var message = response.getOutput().getChoices().get(0).getMessage();
var assistantMessage = com.alibaba.dashscope.common.Message.builder().role(Role.ASSISTANT.getValue())
.content("").build();
assistantMessage.setToolCalls(message.getToolCalls());
return assistantMessage;
}
@Override
protected GenerationResult doChatCompletion(ConversationParam request) {
GenerationResult result;
try {
result = generation.call(request);
}
catch (NoApiKeyException | InputRequiredException e) {
throw new RuntimeException(e);
}
return result;
}
@Override
protected Flux<GenerationResult> doChatCompletionStream(ConversationParam request) {
final Flowable<GenerationResult> genRes;
try {
genRes = generation.streamCall(request);
}
catch (NoApiKeyException | InputRequiredException e) {
logger.warn("TongYi chat client: " + e.getMessage());
throw new TongYiException(e.getMessage());
}
return Flux.from(genRes);
}
@Override
protected boolean isToolFunctionCall(GenerationResult response) {
if (response == null || CollectionUtils.isEmpty(response.getOutput().getChoices())) {
return false;
}
var choice = response.getOutput().getChoices().get(0);
if (choice == null || choice.getFinishReason() == null) {
return false;
}
return Objects.equals(choice.getFinishReason(), ApiKeywords.TOOL_CALLS);
}
}

View File

@@ -1,463 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.chat;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.util.Assert;
import java.util.*;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiChatOptions implements FunctionCallingOptions, ChatOptions {
/**
* TongYi Models.
* {@link Generation.Models}
*/
private String model = Generation.Models.QWEN_TURBO;
/**
* The random number seed used in generation, the user controls the randomness of the content generated by the model.
* seed supports unsigned 64-bit integers, with a default value of 1234.
* when using seed, the model will generate the same or similar results as much as possible, but there is currently no guarantee that the results will be exactly the same each time.
*/
private Integer seed = 1234;
/**
* Used to specify the maximum number of tokens that the model can generate when generating content,
* it defines the upper limit of generation but does not guarantee that this number will be generated every time.
* For qwen-turbo the maximum and default values are 1500 tokens.
* The qwen-max, qwen-max-1201, qwen-max-longcontext, and qwen-plus models have a maximum and default value of 2000 tokens.
*/
private Integer maxTokens = 1500;
/**
* The generation process kernel sampling method probability threshold,
* for example, takes the value of 0.8, only retains the smallest set of the most probable tokens with probabilities that add up to greater than or equal to 0.8 as the candidate set.
* The range of values is (0,1.0), the larger the value, the higher the randomness of generation; the lower the value, the higher the certainty of generation.
*/
private Double topP = 0.8;
/**
* The size of the sampling candidate set at the time of generation.
* For example, with a value of 50, only the 50 highest scoring tokens in a single generation will form a randomly sampled candidate set.
* The larger the value, the higher the randomness of the generation; the smaller the value, the higher the certainty of the generation.
* This parameter is not passed by default, and a value of None or when top_k is greater than 100 indicates that the top_k policy is not enabled,
* at which time, only the top_p policy is in effect.
*/
private Integer topK;
/**
* Used to control the repeatability of model generation.
* Increasing repetition_penalty reduces the repetition of model generation. 1.0 means no penalty.
*/
private Double repetitionPenalty = 1.1;
/**
* is used to control the degree of randomness and diversity.
* Specifically, the temperature value controls the extent to which the probability distribution of each candidate word is smoothed when generating text.
* Higher values of temperature reduce the peak of the probability distribution, allowing more low-probability words to be selected and generating more diverse results,
* while lower values of temperature enhance the peak of the probability distribution, making it easier for high-probability words to be selected and generating more certain results.
* Range: [0, 2), 0 is not recommended, meaningless.
* java version >= 2.5.1
*/
private Double temperature = 0.85;
/**
* The stop parameter is used to realize precise control of the content generation process, automatically stopping when the generated content is about to contain the specified string or token_ids,
* and the generated content does not contain the specified content.
* For example, if stop is specified as "Hello", it means stop when "Hello" will be generated; if stop is specified as [37763, 367], it means stop when "Observation" will be generated.
* The stop parameter can be passed as a list of arrays of strings or token_ids to support the scenario of using multiple stops.
* Explanation: Do not mix strings and token_ids in list mode, the element types should be the same in list mode.
*/
private List<String> stop;
/**
* Whether or not to use stream output. When outputting the result in stream mode, the interface returns the result as generator,
* you need to iterate to get the result, the default output is the whole sequence of the current generation for each output,
* the last output is the final result of all the generation, you can change the output mode to non-incremental output by the parameter incremental_output to False.
*/
private Boolean stream = false;
/**
* The model has a built-in Internet search service.
* This parameter controls whether the model refers to the use of Internet search results when generating text. The values are as follows:
* True: enable internet search, the model will use the search result as the reference information in the text generation process, but the model will "judge by itself" whether to use the internet search result based on its internal logic.
* False (default): Internet search is disabled.
*/
private Boolean enableSearch = false;
/**
* [text|message], defaults to text, when it is message,
* the output refers to the message result example.
* It is recommended to prioritize the use of message format.
*/
private String resultFormat = GenerationParam.ResultFormat.MESSAGE;
/**
* Control the streaming output mode, that is, the content will contain the content has been output;
* set to True, will open the incremental output mode, the output will not contain the content has been output,
* you need to splice the whole output, refer to the streaming output sample code.
*/
private Boolean incrementalOutput = false;
/**
* A list of tools that the model can optionally call.
* Currently only functions are supported, and even if multiple functions are entered, the model will only select one to generate the result.
*/
private List<String> tools;
@Override
public Float getTemperature() {
return this.temperature.floatValue();
}
public void setTemperature(Float temperature) {
this.temperature = temperature.doubleValue();
}
@Override
public Float getTopP() {
return this.topP.floatValue();
}
public void setTopP(Float topP) {
this.topP = topP.doubleValue();
}
@Override
public Integer getTopK() {
return this.topK;
}
public void setTopK(Integer topK) {
this.topK = topK;
}
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public Integer getSeed() {
return seed;
}
public String getResultFormat() {
return resultFormat;
}
public void setResultFormat(String resultFormat) {
this.resultFormat = resultFormat;
}
public void setSeed(Integer seed) {
this.seed = seed;
}
public Integer getMaxTokens() {
return maxTokens;
}
public void setMaxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
}
public Float getRepetitionPenalty() {
return repetitionPenalty.floatValue();
}
public void setRepetitionPenalty(Float repetitionPenalty) {
this.repetitionPenalty = repetitionPenalty.doubleValue();
}
public List<String> getStop() {
return stop;
}
public void setStop(List<String> stop) {
this.stop = stop;
}
public Boolean getStream() {
return stream;
}
public void setStream(Boolean stream) {
this.stream = stream;
}
public Boolean getEnableSearch() {
return enableSearch;
}
public void setEnableSearch(Boolean enableSearch) {
this.enableSearch = enableSearch;
}
public Boolean getIncrementalOutput() {
return incrementalOutput;
}
public void setIncrementalOutput(Boolean incrementalOutput) {
this.incrementalOutput = incrementalOutput;
}
public List<String> getTools() {
return tools;
}
public void setTools(List<String> tools) {
this.tools = tools;
}
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
private Set<String> functions = new HashSet<>();
@Override
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
}
@Override
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
this.functionCallbacks = functionCallbacks;
}
@Override
public Set<String> getFunctions() {
return this.functions;
}
@Override
public void setFunctions(Set<String> functions) {
this.functions = functions;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TongYiChatOptions that = (TongYiChatOptions) o;
return Objects.equals(model, that.model)
&& Objects.equals(seed, that.seed)
&& Objects.equals(maxTokens, that.maxTokens)
&& Objects.equals(topP, that.topP)
&& Objects.equals(topK, that.topK)
&& Objects.equals(repetitionPenalty, that.repetitionPenalty)
&& Objects.equals(temperature, that.temperature)
&& Objects.equals(stop, that.stop)
&& Objects.equals(stream, that.stream)
&& Objects.equals(enableSearch, that.enableSearch)
&& Objects.equals(resultFormat, that.resultFormat)
&& Objects.equals(incrementalOutput, that.incrementalOutput)
&& Objects.equals(tools, that.tools)
&& Objects.equals(functionCallbacks, that.functionCallbacks)
&& Objects.equals(functions, that.functions);
}
@Override
public int hashCode() {
return Objects.hash(
model,
seed,
maxTokens,
topP,
topK,
repetitionPenalty,
temperature,
stop,
stream,
enableSearch,
resultFormat,
incrementalOutput,
tools,
functionCallbacks,
functions
);
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder("TongYiChatOptions{");
sb.append(", model='").append(model).append('\'');
sb.append(", seed=").append(seed);
sb.append(", maxTokens=").append(maxTokens);
sb.append(", topP=").append(topP);
sb.append(", topK=").append(topK);
sb.append(", repetitionPenalty=").append(repetitionPenalty);
sb.append(", temperature=").append(temperature);
sb.append(", stop=").append(stop);
sb.append(", stream=").append(stream);
sb.append(", enableSearch=").append(enableSearch);
sb.append(", resultFormat='").append(resultFormat).append('\'');
sb.append(", incrementalOutput=").append(incrementalOutput);
sb.append(", tools=").append(tools);
sb.append(", functionCallbacks=").append(functionCallbacks);
sb.append(", functions=").append(functions);
sb.append('}');
return sb.toString();
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
protected TongYiChatOptions options;
public Builder() {
this.options = new TongYiChatOptions();
}
public Builder(TongYiChatOptions options) {
this.options = options;
}
public Builder withModel(String model) {
this.options.model = model;
return this;
}
public Builder withMaxTokens(Integer maxTokens) {
this.options.maxTokens = maxTokens;
return this;
}
public Builder withResultFormat(String rf) {
this.options.resultFormat = rf;
return this;
}
public Builder withEnableSearch(Boolean enableSearch) {
this.options.enableSearch = enableSearch;
return this;
}
public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
this.options.functionCallbacks = functionCallbacks;
return this;
}
public Builder withFunctions(Set<String> functionNames) {
Assert.notNull(functionNames, "Function names must not be null");
this.options.functions = functionNames;
return this;
}
public Builder withFunction(String functionName) {
Assert.hasText(functionName, "Function name must not be empty");
this.options.functions.add(functionName);
return this;
}
public Builder withSeed(Integer seed) {
this.options.seed = seed;
return this;
}
public Builder withStop(List<String> stop) {
this.options.stop = stop;
return this;
}
public Builder withTemperature(Double temperature) {
this.options.temperature = temperature;
return this;
}
public Builder withTopP(Double topP) {
this.options.topP = topP;
return this;
}
public Builder withTopK(Integer topK) {
this.options.topK = topK;
return this;
}
public Builder withRepetitionPenalty(Double repetitionPenalty) {
this.options.repetitionPenalty = repetitionPenalty;
return this;
}
public TongYiChatOptions build() {
return this.options;
}
}
}

View File

@@ -1,83 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.chat;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
@ConfigurationProperties(TongYiChatProperties.CONFIG_PREFIX)
public class TongYiChatProperties {
/**
* Spring Cloud Alibaba AI configuration prefix.
*/
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "chat";
/**
* Default TongYi Chat model.
*/
public static final String DEFAULT_DEPLOYMENT_NAME = Generation.Models.QWEN_TURBO;
/**
* Default temperature speed.
*/
private static final Double DEFAULT_TEMPERATURE = 0.8;
/**
* Enable TongYiQWEN ai chat client.
*/
private boolean enabled = true;
@NestedConfigurationProperty
private TongYiChatOptions options = TongYiChatOptions.builder()
.withModel(DEFAULT_DEPLOYMENT_NAME)
.withTemperature(DEFAULT_TEMPERATURE)
.withEnableSearch(true)
.withResultFormat(GenerationParam.ResultFormat.MESSAGE)
.build();
public TongYiChatOptions getOptions() {
return this.options;
}
public void setOptions(TongYiChatOptions options) {
this.options = options;
}
public boolean isEnabled() {
return this.enabled;
}
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
}

View File

@@ -1,44 +0,0 @@
/*
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.common.constants;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
*/
public final class TongYiConstants {
private TongYiConstants() {
}
/**
* Spring Cloud Alibaba AI configuration prefix.
*/
public static final String SCA_AI_CONFIGURATION = "spring.cloud.ai.tongyi.";
/**
* Spring Cloud Alibaba AI constants prefix.
*/
public static final String SCA_AI = "SPRING_CLOUD_ALIBABA_";
/**
* TongYi AI apikey env name.
*/
public static final String SCA_AI_TONGYI_API_KEY = SCA_AI + "TONGYI_API_KEY";
}

View File

@@ -1,38 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.common.exception;
/**
* TongYi models runtime exception.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiException extends RuntimeException {
public TongYiException(String message) {
super(message);
}
public TongYiException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -1,39 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.common.exception;
/**
* TongYi models images exception.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiImagesException extends TongYiException {
public TongYiImagesException(String message) {
super(message);
}
public TongYiImagesException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@@ -1,84 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.embedding;
import com.alibaba.dashscope.embeddings.TextEmbeddingParam;
import org.springframework.ai.embedding.EmbeddingOptions;
import java.util.List;
/**
* @author why_ohh
* @author yuluo
* @author <a href="mailto:550588941@qq.com">why_ohh</a>
* @since 2023.0.1.0
*/
public final class TongYiEmbeddingOptions implements EmbeddingOptions {
private List<String> texts;
private TextEmbeddingParam.TextType textType;
public List<String> getTexts() {
return texts;
}
public void setTexts(List<String> texts) {
this.texts = texts;
}
public TextEmbeddingParam.TextType getTextType() {
return textType;
}
public void setTextType(TextEmbeddingParam.TextType textType) {
this.textType = textType;
}
public static Builder builder() {
return new Builder();
}
public final static class Builder {
private final TongYiEmbeddingOptions options;
private Builder() {
this.options = new TongYiEmbeddingOptions();
}
public Builder withtexts(List<String> texts) {
options.setTexts(texts);
return this;
}
public Builder withtextType(TextEmbeddingParam.TextType textType) {
options.setTextType(textType);
return this;
}
public TongYiEmbeddingOptions build() {
return options;
}
}
}

View File

@@ -1,176 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.embedding;
import cn.hutool.core.collection.ListUtil;
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiException;
import com.alibaba.cloud.ai.tongyi.metadata.TongYiTextEmbeddingResponseMetadata;
import com.alibaba.dashscope.embeddings.TextEmbedding;
import com.alibaba.dashscope.embeddings.TextEmbeddingParam;
import com.alibaba.dashscope.embeddings.TextEmbeddingResult;
import com.alibaba.dashscope.embeddings.TextEmbeddingResultItem;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.util.Assert;
import java.util.List;
import java.util.stream.Collectors;
/**
* {@link TongYiTextEmbeddingModel} implementation for {@literal Alibaba DashScope}.
*
* @author why_ohh
* @author yuluo
* @author <a href="mailto:550588941@qq.com">why_ohh</a>
* @since 2023.0.1.0
*/
public class TongYiTextEmbeddingModel extends AbstractEmbeddingModel {
private final Logger logger = LoggerFactory.getLogger(TongYiTextEmbeddingModel.class);
/**
* TongYi Text Embedding client.
*/
private final TextEmbedding textEmbedding;
/**
* {@link MetadataMode}.
*/
private final MetadataMode metadataMode;
private final TongYiEmbeddingOptions defaultOptions;
public TongYiTextEmbeddingModel(TextEmbedding textEmbedding) {
this(textEmbedding, MetadataMode.EMBED);
}
public TongYiTextEmbeddingModel(TextEmbedding textEmbedding, MetadataMode metadataMode) {
this(textEmbedding, metadataMode,
TongYiEmbeddingOptions.builder()
.withtextType(TextEmbeddingParam.TextType.DOCUMENT)
.build()
);
}
public TongYiTextEmbeddingModel(
TextEmbedding textEmbedding,
MetadataMode metadataMode,
TongYiEmbeddingOptions options
) {
Assert.notNull(textEmbedding, "textEmbedding must not be null");
Assert.notNull(metadataMode, "Metadata mode must not be null");
Assert.notNull(options, "TongYiEmbeddingOptions must not be null");
this.metadataMode = metadataMode;
this.textEmbedding = textEmbedding;
this.defaultOptions = options;
}
public TongYiEmbeddingOptions getDefaultOptions() {
return this.defaultOptions;
}
@Override
public List<Double> embed(Document document) {
return this.call(
new EmbeddingRequest(
ListUtil.of(document.getFormattedContent(this.metadataMode)),
null)
).getResults().stream()
.map(Embedding::getOutput)
.flatMap(List::stream)
.toList();
}
@Override
public EmbeddingResponse call(EmbeddingRequest request) {
TextEmbeddingParam embeddingParams = toEmbeddingParams(request);
logger.debug("Embedding request: {}", embeddingParams);
TextEmbeddingResult resp;
try {
resp = textEmbedding.call(embeddingParams);
}
catch (NoApiKeyException e) {
throw new TongYiException(e.getMessage());
}
return genEmbeddingResp(resp);
}
private EmbeddingResponse genEmbeddingResp(TextEmbeddingResult result) {
return new EmbeddingResponse(
genEmbeddingList(result.getOutput().getEmbeddings()),
TongYiTextEmbeddingResponseMetadata.from(result.getUsage())
);
}
private List<Embedding> genEmbeddingList(List<TextEmbeddingResultItem> embeddings) {
return embeddings.stream()
.map(embedding ->
new Embedding(
embedding.getEmbedding(),
embedding.getTextIndex()
))
.collect(Collectors.toList());
}
/**
* We recommend setting the model parameters by passing the embedding parameters through the code;
* yml configuration compatibility is not considered here.
* It is not recommended that users set parameters from yml,
* as this reduces the flexibility of the configuration.
* Because the model name keeps changing, strings are used here and constants are undefined:
* Model list: <a href="https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-quick-start">Text Embedding Model List</a>
* @param requestOptions Client params. {@link EmbeddingRequest}
* @return {@link TextEmbeddingParam}
*/
private TextEmbeddingParam toEmbeddingParams(EmbeddingRequest requestOptions) {
TextEmbeddingParam tongYiEmbeddingParams = TextEmbeddingParam.builder()
.texts(requestOptions.getInstructions())
.textType(defaultOptions.getTextType() != null ? defaultOptions.getTextType() : TextEmbeddingParam.TextType.DOCUMENT)
.model("text-embedding-v1")
.build();
try {
tongYiEmbeddingParams.validate();
}
catch (InputRequiredException e) {
throw new TongYiException(e.getMessage());
}
return tongYiEmbeddingParams;
}
}

View File

@@ -1,50 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.embedding;
import org.springframework.boot.context.properties.ConfigurationProperties;
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
/**
* @author why_ohh
* @author yuluo
* @author <a href="mailto:550588941@qq.com">why_ohh</a>
* @since 2023.0.1.0
*/
@ConfigurationProperties(TongYiTextEmbeddingProperties.CONFIG_PREFIX)
public class TongYiTextEmbeddingProperties {
/**
* Prefix of TongYi Text Embedding properties.
*/
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "embedding";
private boolean enabled = true;
public boolean isEnabled() {
return this.enabled;
}
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
}

View File

@@ -1,237 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.image;
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiImagesException;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
import com.alibaba.dashscope.exception.NoApiKeyException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.image.*;
import org.springframework.util.Assert;
import java.io.ByteArrayOutputStream;
import java.net.URL;
import java.util.Base64;
import java.util.stream.Collectors;
import static com.alibaba.cloud.ai.tongyi.metadata.TongYiImagesResponseMetadata.from;
/**
* TongYiImagesClient is a class that implements the ImageClient interface. It provides a
* client for calling the TongYi AI image generation API.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiImagesModel implements ImageModel {
private final Logger logger = LoggerFactory.getLogger(TongYiImagesModel.class);
/**
* Gen Images API.
*/
private final ImageSynthesis imageSynthesis;
/**
* TongYi Gen images properties.
*/
private TongYiImagesOptions defaultOptions;
/**
* Adapt TongYi images api size properties.
*/
private final String sizeConnection = "*";
/**
* Get default images options.
*
* @return Default TongYiImagesOptions.
*/
public TongYiImagesOptions getDefaultOptions() {
return this.defaultOptions;
}
/**
* TongYiImagesClient constructor.
* @param imageSynthesis the image synthesis
* {@link ImageSynthesis}
*/
public TongYiImagesModel(ImageSynthesis imageSynthesis) {
this(imageSynthesis, TongYiImagesOptions.
builder()
.withModel(ImageSynthesis.Models.WANX_V1)
.withN(1)
.build()
);
}
/**
* TongYiImagesClient constructor.
* @param imageSynthesis {@link ImageSynthesis}
* @param imagesOptions {@link TongYiImagesOptions}
*/
public TongYiImagesModel(ImageSynthesis imageSynthesis, TongYiImagesOptions imagesOptions) {
Assert.notNull(imageSynthesis, "ImageSynthesis must not be null");
Assert.notNull(imagesOptions, "TongYiImagesOptions must not be null");
this.imageSynthesis = imageSynthesis;
this.defaultOptions = imagesOptions;
}
/**
* Call the TongYi images service.
* @param imagePrompt the image prompt.
* @return the image response.
* {@link ImageSynthesis#call(ImageSynthesisParam)}
*/
@Override
public ImageResponse call(ImagePrompt imagePrompt) {
ImageSynthesisResult result;
String prompt = imagePrompt.getInstructions().get(0).getText();
var imgParams = ImageSynthesisParam.builder()
.prompt("")
.model(ImageSynthesis.Models.WANX_V1)
.build();
if (this.defaultOptions != null) {
imgParams = merge(this.defaultOptions);
}
if (imagePrompt.getOptions() != null) {
imgParams = merge(toTingYiImageOptions(imagePrompt.getOptions()));
}
imgParams.setPrompt(prompt);
try {
result = imageSynthesis.call(imgParams);
}
catch (NoApiKeyException e) {
logger.error("TongYi models gen images failed: {}.", e.getMessage());
throw new TongYiImagesException(e.getMessage());
}
return convert(result);
}
public ImageSynthesisParam merge(TongYiImagesOptions target) {
var builder = ImageSynthesisParam.builder();
builder.model(this.defaultOptions.getModel() != null ? this.defaultOptions.getModel() : target.getModel());
builder.n(this.defaultOptions.getN() != null ? this.defaultOptions.getN() : target.getN());
builder.size((this.defaultOptions.getHeight() != null && this.defaultOptions.getWidth() != null)
? this.defaultOptions.getHeight() + "*" + this.defaultOptions.getWidth()
: target.getHeight() + "*" + target.getWidth()
);
// prompt is marked non-null but is null.
builder.prompt("");
return builder.build();
}
private ImageResponse convert(ImageSynthesisResult result) {
return new ImageResponse(
result.getOutput().getResults().stream()
.flatMap(value -> value.entrySet().stream())
.map(entry -> {
String key = entry.getKey();
String value = entry.getValue();
try {
String base64Image = convertImageToBase64(value);
Image image = new Image(value, base64Image);
return new ImageGeneration(image);
}
catch (Exception e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toList()),
from(result)
);
}
public TongYiImagesOptions toTingYiImageOptions(ImageOptions runtimeImageOptions) {
var builder = TongYiImagesOptions.builder();
if (runtimeImageOptions != null) {
if (runtimeImageOptions.getN() != null) {
builder.withN(runtimeImageOptions.getN());
}
if (runtimeImageOptions.getModel() != null) {
builder.withModel(runtimeImageOptions.getModel());
}
if (runtimeImageOptions.getHeight() != null) {
builder.withHeight(runtimeImageOptions.getHeight());
}
if (runtimeImageOptions.getWidth() != null) {
builder.withWidth(runtimeImageOptions.getWidth());
}
// todo ImagesParams.
}
return builder.build();
}
/**
* Convert image to base64.
* @param imageUrl the image url.
* @return the base64 image.
* @throws Exception the exception.
*/
public String convertImageToBase64(String imageUrl) throws Exception {
var url = new URL(imageUrl);
var inputStream = url.openStream();
var outputStream = new ByteArrayOutputStream();
var buffer = new byte[4096];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
outputStream.write(buffer, 0, bytesRead);
}
var imageBytes = outputStream.toByteArray();
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
inputStream.close();
outputStream.close();
return base64Image;
}
}

View File

@@ -1,187 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.image;
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiImagesException;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import org.springframework.ai.image.ImageOptions;
import java.util.Objects;
/**
* TongYi Image API options.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiImagesOptions implements ImageOptions {
/**
* Specify the model name, currently only wanx-v1 is supported.
*/
private String model = ImageSynthesis.Models.WANX_V1;
/**
* Gen images number.
*/
private Integer n;
/**
* The width of the generated images.
*/
private Integer width = 1024;
/**
* The height of the generated images.
*/
private Integer height = 1024;
@Override
public Integer getN() {
return this.n;
}
@Override
public String getModel() {
return this.model;
}
@Override
public Integer getWidth() {
return this.width;
}
@Override
public Integer getHeight() {
return this.height;
}
@Override
public String getResponseFormat() {
throw new TongYiImagesException("unimplemented!");
}
public void setModel(String model) {
this.model = model;
}
public void setN(Integer n) {
this.n = n;
}
public void setWidth(Integer width) {
this.width = width;
}
public void setHeight(Integer height) {
this.height = height;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TongYiImagesOptions that = (TongYiImagesOptions) o;
return Objects.equals(model, that.model)
&& Objects.equals(n, that.n)
&& Objects.equals(width, that.width)
&& Objects.equals(height, that.height);
}
@Override
public int hashCode() {
return Objects.hash(model, n, width, height);
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder("TongYiImagesOptions{");
sb.append("model='").append(model).append('\'');
sb.append(", n=").append(n);
sb.append(", width=").append(width);
sb.append(", height=").append(height);
sb.append('}');
return sb.toString();
}
public static Builder builder() {
return new Builder();
}
public final static class Builder {
private final TongYiImagesOptions options;
private Builder() {
this.options = new TongYiImagesOptions();
}
public Builder withN(Integer n) {
options.setN(n);
return this;
}
public Builder withModel(String model) {
options.setModel(model);
return this;
}
public Builder withWidth(Integer width) {
options.setWidth(width);
return this;
}
public Builder withHeight(Integer height) {
options.setHeight(height);
return this;
}
public TongYiImagesOptions build() {
return options;
}
}
}

View File

@@ -1,77 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.image;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
/**
* TongYi Image API properties.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
@ConfigurationProperties(TongYiImagesProperties.CONFIG_PREFIX)
public class TongYiImagesProperties {
/**
* Spring Cloud Alibaba AI configuration prefix.
*/
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "images";
/**
* Default TongYi Chat model.
*/
public static final String DEFAULT_IMAGES_MODEL_NAME = ImageSynthesis.Models.WANX_V1;
/**
* Enable TongYiQWEN ai images client.
*/
private boolean enabled = true;
@NestedConfigurationProperty
private TongYiImagesOptions options = TongYiImagesOptions.builder()
.withModel(DEFAULT_IMAGES_MODEL_NAME)
.withN(1)
.build();
public TongYiImagesOptions getOptions() {
return this.options;
}
public void setOptions(TongYiImagesOptions options) {
this.options = options;
}
public boolean isEnabled() {
return this.enabled;
}
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
}

View File

@@ -1,89 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;
import java.util.HashMap;
/**
* {@link ChatResponseMetadata} implementation for {@literal Alibaba DashScope}.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }";
@SuppressWarnings("all")
public static TongYiAiChatResponseMetadata from(GenerationResult chatCompletions,
PromptMetadata promptFilterMetadata) {
Assert.notNull(chatCompletions, "Alibaba ai ChatCompletions must not be null");
String id = chatCompletions.getRequestId();
TongYiAiUsage usage = TongYiAiUsage.from(chatCompletions);
return new TongYiAiChatResponseMetadata(
id,
usage,
promptFilterMetadata
);
}
private final String id;
private final Usage usage;
private final PromptMetadata promptMetadata;
protected TongYiAiChatResponseMetadata(String id, TongYiAiUsage usage, PromptMetadata promptMetadata) {
this.id = id;
this.usage = usage;
this.promptMetadata = promptMetadata;
}
public String getId() {
return this.id;
}
@Override
public Usage getUsage() {
return this.usage;
}
@Override
public PromptMetadata getPromptMetadata() {
return this.promptMetadata;
}
@Override
public String toString() {
return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getUsage(), getRateLimit());
}
}

View File

@@ -1,81 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.GenerationUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;
/**
* {@link Usage} implementation for {@literal Alibaba DashScope}.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiAiUsage implements Usage {
private final GenerationUsage usage;
public TongYiAiUsage(GenerationUsage usage) {
Assert.notNull(usage, "GenerationUsage must not be null");
this.usage = usage;
}
public static TongYiAiUsage from(GenerationResult chatCompletions) {
Assert.notNull(chatCompletions, "ChatCompletions must not be null");
return from(chatCompletions.getUsage());
}
public static TongYiAiUsage from(GenerationUsage usage) {
return new TongYiAiUsage(usage);
}
protected GenerationUsage getUsage() {
return this.usage;
}
@Override
public Long getPromptTokens() {
throw new UnsupportedOperationException("Unimplemented method 'getPromptTokens'");
}
@Override
public Long getGenerationTokens() {
return this.getUsage().getOutputTokens().longValue();
}
@Override
public Long getTotalTokens() {
return this.getUsage().getTotalTokens().longValue();
}
@Override
public String toString() {
return this.getUsage().toString();
}
}

View File

@@ -1,131 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisTaskMetrics;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisUsage;
import org.springframework.ai.image.ImageResponseMetadata;
import org.springframework.util.Assert;
import java.util.HashMap;
import java.util.Objects;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiImagesResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
private final Long created;
private String taskId;
private ImageSynthesisTaskMetrics metrics;
private ImageSynthesisUsage usage;
public static TongYiImagesResponseMetadata from(ImageSynthesisResult synthesisResult) {
Assert.notNull(synthesisResult, "TongYiAiImageResponse must not be null");
return new TongYiImagesResponseMetadata(
System.currentTimeMillis(),
synthesisResult.getOutput().getTaskMetrics(),
synthesisResult.getOutput().getTaskId(),
synthesisResult.getUsage()
);
}
protected TongYiImagesResponseMetadata(
Long created,
ImageSynthesisTaskMetrics metrics,
String taskId,
ImageSynthesisUsage usage
) {
this.taskId = taskId;
this.metrics = metrics;
this.created = created;
this.usage = usage;
}
public ImageSynthesisUsage getUsage() {
return usage;
}
public void setUsage(ImageSynthesisUsage usage) {
this.usage = usage;
}
@Override
public Long getCreated() {
return created;
}
public String getTaskId() {
return taskId;
}
public void setTaskId(String taskId) {
this.taskId = taskId;
}
public ImageSynthesisTaskMetrics getMetrics() {
return metrics;
}
void setMetrics(ImageSynthesisTaskMetrics metrics) {
this.metrics = metrics;
}
public Long created() {
return this.created;
}
@Override
public String toString() {
return "TongYiImagesResponseMetadata {" + "created=" + created + '}';
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TongYiImagesResponseMetadata that = (TongYiImagesResponseMetadata) o;
return Objects.equals(created, that.created)
&& Objects.equals(taskId, that.taskId)
&& Objects.equals(metrics, that.metrics);
}
@Override
public int hashCode() {
return Objects.hash(created, taskId, metrics);
}
}

View File

@@ -1,53 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata;
import com.alibaba.dashscope.embeddings.TextEmbeddingUsage;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
/**
* @author why_ohh
* @author yuluo
* @author <a href="mailto:550588941@qq.com">why_ohh</a>
* @since 2023.0.1.0
*/
public class TongYiTextEmbeddingResponseMetadata extends EmbeddingResponseMetadata {
private Integer totalTokens;
protected TongYiTextEmbeddingResponseMetadata(Integer totalTokens) {
this.totalTokens = totalTokens;
}
public static TongYiTextEmbeddingResponseMetadata from(TextEmbeddingUsage usage) {
return new TongYiTextEmbeddingResponseMetadata(usage.getTotalTokens());
}
public Integer getTotalTokens() {
return totalTokens;
}
public void setTotalTokens(Integer totalTokens) {
this.totalTokens = totalTokens;
}
}

View File

@@ -1,133 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata.audio;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisResult;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisUsage;
import com.alibaba.dashscope.audio.tts.timestamp.Sentence;
import org.springframework.ai.chat.metadata.EmptyRateLimit;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.model.ResponseMetadata;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import java.util.HashMap;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiAudioSpeechResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
private SpeechSynthesisUsage usage;
private String requestId;
private Sentence time;
protected static final String AI_METADATA_STRING = "{ @type: %1$s, requestsLimit: %2$s }";
/**
* NULL objects.
*/
public static final TongYiAudioSpeechResponseMetadata NULL = new TongYiAudioSpeechResponseMetadata() {
};
public static TongYiAudioSpeechResponseMetadata from(SpeechSynthesisResult result) {
Assert.notNull(result, "TongYi AI speech must not be null");
TongYiAudioSpeechResponseMetadata speechResponseMetadata = new TongYiAudioSpeechResponseMetadata();
return speechResponseMetadata;
}
public static TongYiAudioSpeechResponseMetadata from(String result) {
Assert.notNull(result, "TongYi AI speech must not be null");
TongYiAudioSpeechResponseMetadata speechResponseMetadata = new TongYiAudioSpeechResponseMetadata();
return speechResponseMetadata;
}
@Nullable
private RateLimit rateLimit;
public TongYiAudioSpeechResponseMetadata() {
this(null);
}
public TongYiAudioSpeechResponseMetadata(@Nullable RateLimit rateLimit) {
this.rateLimit = rateLimit;
}
@Nullable
public RateLimit getRateLimit() {
RateLimit rateLimit = this.rateLimit;
return rateLimit != null ? rateLimit : new EmptyRateLimit();
}
public TongYiAudioSpeechResponseMetadata withRateLimit(RateLimit rateLimit) {
this.rateLimit = rateLimit;
return this;
}
public TongYiAudioSpeechResponseMetadata withUsage(SpeechSynthesisUsage usage) {
this.usage = usage;
return this;
}
public TongYiAudioSpeechResponseMetadata withRequestId(String id) {
this.requestId = id;
return this;
}
public TongYiAudioSpeechResponseMetadata withSentence(Sentence sentence) {
this.time = sentence;
return this;
}
public SpeechSynthesisUsage getUsage() {
return usage;
}
public String getRequestId() {
return requestId;
}
public Sentence getTime() {
return time;
}
@Override
public String toString() {
return AI_METADATA_STRING.formatted(getClass().getName(), getRateLimit());
}
}

View File

@@ -1,43 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata.audio;
import org.springframework.ai.model.ResultMetadata;
/**
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
public interface TongYiAudioTranscriptionMetadata extends ResultMetadata {
/**
* A constant instance of {@link TongYiAudioTranscriptionMetadata} that represents a null or empty metadata.
*/
TongYiAudioTranscriptionMetadata NULL = TongYiAudioTranscriptionMetadata.create();
/**
* Factory method for creating a new instance of {@link TongYiAudioTranscriptionMetadata}.
* @return a new instance of {@link TongYiAudioTranscriptionMetadata}
*/
static TongYiAudioTranscriptionMetadata create() {
return new TongYiAudioTranscriptionMetadata() {
};
}
}

View File

@@ -1,98 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata.audio;
import com.alibaba.dashscope.audio.asr.transcription.TranscriptionResult;
import com.google.gson.JsonObject;
import org.springframework.ai.chat.metadata.EmptyRateLimit;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.model.ResponseMetadata;
import org.springframework.util.Assert;
import javax.annotation.Nullable;
import java.util.HashMap;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiAudioTranscriptionResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
@Nullable
private RateLimit rateLimit;
private JsonObject usage;
protected static final String AI_METADATA_STRING = "{ @type: %1$s, rateLimit: %4$s }";
/**
* NULL objects.
*/
public static final TongYiAudioTranscriptionResponseMetadata NULL = new TongYiAudioTranscriptionResponseMetadata() {
};
protected TongYiAudioTranscriptionResponseMetadata() {
this(null, new JsonObject());
}
protected TongYiAudioTranscriptionResponseMetadata(JsonObject usage) {
this(null, usage);
}
protected TongYiAudioTranscriptionResponseMetadata(@Nullable RateLimit rateLimit, JsonObject usage) {
this.rateLimit = rateLimit;
this.usage = usage;
}
public static TongYiAudioTranscriptionResponseMetadata from(TranscriptionResult result) {
Assert.notNull(result, "TongYi Transcription must not be null");
return new TongYiAudioTranscriptionResponseMetadata(result.getUsage());
}
@Nullable
public RateLimit getRateLimit() {
return this.rateLimit != null ? this.rateLimit : new EmptyRateLimit();
}
public void setRateLimit(@Nullable RateLimit rateLimit) {
this.rateLimit = rateLimit;
}
public JsonObject getUsage() {
return usage;
}
public void setUsage(JsonObject usage) {
this.usage = usage;
}
@Override
public String toString() {
return AI_METADATA_STRING.formatted(getClass().getName(), getRateLimit());
}
}

View File

@@ -1,61 +0,0 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.autoconfigure.vectorstore.redis;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.RedisVectorStore.RedisVectorStoreConfig;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
import redis.clients.jedis.JedisPooled;
/**
* TODO @xin 先拿 spring-ai 最新代码覆盖1.0.0-M1 跟 redis 自动配置会冲突
*
* TODO 这个官方,有说啥时候 fix 哇?
* TODO 看着是列在1.0.0-M2版本
*
* @author Christian Tzolov
* @author Eddú Meléndez
*/
@AutoConfiguration(after = RedisAutoConfiguration.class)
@ConditionalOnClass({JedisPooled.class, JedisConnectionFactory.class, RedisVectorStore.class, EmbeddingModel.class})
@ConditionalOnBean(JedisConnectionFactory.class)
@EnableConfigurationProperties(RedisVectorStoreProperties.class)
public class RedisVectorStoreAutoConfiguration {
@Bean
@ConditionalOnMissingBean
public RedisVectorStore vectorStore(EmbeddingModel embeddingModel, RedisVectorStoreProperties properties,
JedisConnectionFactory jedisConnectionFactory) {
var config = RedisVectorStoreConfig.builder()
.withIndexName(properties.getIndex())
.withPrefix(properties.getPrefix())
.build();
return new RedisVectorStore(config, embeddingModel,
new JedisPooled(jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()),
properties.isInitializeSchema());
}
}

View File

@@ -1,456 +0,0 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.vectorstore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.json.Path2;
import redis.clients.jedis.search.*;
import redis.clients.jedis.search.Schema.FieldType;
import redis.clients.jedis.search.schemafields.*;
import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm;
import java.text.MessageFormat;
import java.util.*;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
/**
* The RedisVectorStore is for managing and querying vector data in a Redis database. It
* offers functionalities like adding, deleting, and performing similarity searches on
* documents.
*
* The store utilizes RedisJSON and RedisSearch to handle JSON documents and to index and
* search vector data. It supports various vector algorithms (e.g., FLAT, HSNW) for
* efficient similarity searches. Additionally, it allows for custom metadata fields in
* the documents to be stored alongside the vector and content data.
*
* This class requires a RedisVectorStoreConfig configuration object for initialization,
* which includes settings like Redis URI, index name, field names, and vector algorithms.
* It also requires an EmbeddingModel to convert documents into embeddings before storing
* them.
*
* @author Julien Ruaux
* @author Christian Tzolov
* @author Eddú Meléndez
* @see VectorStore
* @see RedisVectorStoreConfig
* @see EmbeddingModel
*/
public class RedisVectorStore implements VectorStore, InitializingBean {
public enum Algorithm {
FLAT, HSNW
}
public record MetadataField(String name, FieldType fieldType) {
public static MetadataField text(String name) {
return new MetadataField(name, FieldType.TEXT);
}
public static MetadataField numeric(String name) {
return new MetadataField(name, FieldType.NUMERIC);
}
public static MetadataField tag(String name) {
return new MetadataField(name, FieldType.TAG);
}
}
/**
* Configuration for the Redis vector store.
*/
public static final class RedisVectorStoreConfig {
private final String indexName;
private final String prefix;
private final String contentFieldName;
private final String embeddingFieldName;
private final Algorithm vectorAlgorithm;
private final List<MetadataField> metadataFields;
private RedisVectorStoreConfig() {
this(builder());
}
private RedisVectorStoreConfig(Builder builder) {
this.indexName = builder.indexName;
this.prefix = builder.prefix;
this.contentFieldName = builder.contentFieldName;
this.embeddingFieldName = builder.embeddingFieldName;
this.vectorAlgorithm = builder.vectorAlgorithm;
this.metadataFields = builder.metadataFields;
}
/**
* Start building a new configuration.
* @return The entry point for creating a new configuration.
*/
public static Builder builder() {
return new Builder();
}
/**
* {@return the default config}
*/
public static RedisVectorStoreConfig defaultConfig() {
return builder().build();
}
public static class Builder {
private String indexName = DEFAULT_INDEX_NAME;
private String prefix = DEFAULT_PREFIX;
private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME;
private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME;
private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM;
private List<MetadataField> metadataFields = new ArrayList<>();
private Builder() {
}
/**
* Configures the Redis index name to use.
* @param name the index name to use
* @return this builder
*/
public Builder withIndexName(String name) {
this.indexName = name;
return this;
}
/**
* Configures the Redis key prefix to use (default: "embedding:").
* @param prefix the prefix to use
* @return this builder
*/
public Builder withPrefix(String prefix) {
this.prefix = prefix;
return this;
}
/**
* Configures the Redis content field name to use.
* @param name the content field name to use
* @return this builder
*/
public Builder withContentFieldName(String name) {
this.contentFieldName = name;
return this;
}
/**
* Configures the Redis embedding field name to use.
* @param name the embedding field name to use
* @return this builder
*/
public Builder withEmbeddingFieldName(String name) {
this.embeddingFieldName = name;
return this;
}
/**
* Configures the Redis vector algorithmto use.
* @param algorithm the vector algorithm to use
* @return this builder
*/
public Builder withVectorAlgorithm(Algorithm algorithm) {
this.vectorAlgorithm = algorithm;
return this;
}
public Builder withMetadataFields(MetadataField... fields) {
return withMetadataFields(Arrays.asList(fields));
}
public Builder withMetadataFields(List<MetadataField> fields) {
this.metadataFields = fields;
return this;
}
/**
* {@return the immutable configuration}
*/
public RedisVectorStoreConfig build() {
return new RedisVectorStoreConfig(this);
}
}
}
private final boolean initializeSchema;
public static final String DEFAULT_INDEX_NAME = "spring-ai-index";
public static final String DEFAULT_CONTENT_FIELD_NAME = "content";
public static final String DEFAULT_EMBEDDING_FIELD_NAME = "embedding";
public static final String DEFAULT_PREFIX = "embedding:";
public static final Algorithm DEFAULT_VECTOR_ALGORITHM = Algorithm.HSNW;
private static final String QUERY_FORMAT = "%s=>[KNN %s @%s $%s AS %s]";
private static final Path2 JSON_SET_PATH = Path2.of("$");
private static final String JSON_PATH_PREFIX = "$.";
private static final Logger logger = LoggerFactory.getLogger(RedisVectorStore.class);
private static final Predicate<Object> RESPONSE_OK = Predicate.isEqual("OK");
private static final Predicate<Object> RESPONSE_DEL_OK = Predicate.isEqual(1l);
private static final String VECTOR_TYPE_FLOAT32 = "FLOAT32";
private static final String EMBEDDING_PARAM_NAME = "BLOB";
public static final String DISTANCE_FIELD_NAME = "vector_score";
private static final String DEFAULT_DISTANCE_METRIC = "COSINE";
private final JedisPooled jedis;
private final EmbeddingModel embeddingModel;
private final RedisVectorStoreConfig config;
private FilterExpressionConverter filterExpressionConverter;
public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis,
boolean initializeSchema) {
Assert.notNull(config, "Config must not be null");
Assert.notNull(embeddingModel, "Embedding model must not be null");
this.initializeSchema = initializeSchema;
this.jedis = jedis;
this.embeddingModel = embeddingModel;
this.config = config;
this.filterExpressionConverter = new RedisFilterExpressionConverter(this.config.metadataFields);
}
public JedisPooled getJedis() {
return this.jedis;
}
@Override
public void add(List<Document> documents) {
try (Pipeline pipeline = this.jedis.pipelined()) {
for (Document document : documents) {
var embedding = this.embeddingModel.embed(document);
document.setEmbedding(embedding);
var fields = new HashMap<String, Object>();
fields.put(this.config.embeddingFieldName, embedding);
fields.put(this.config.contentFieldName, document.getContent());
fields.putAll(document.getMetadata());
pipeline.jsonSetWithEscape(key(document.getId()), JSON_SET_PATH, fields);
}
List<Object> responses = pipeline.syncAndReturnAll();
Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_OK)).findAny();
if (errResponse.isPresent()) {
String message = MessageFormat.format("Could not add document: {0}", errResponse.get());
if (logger.isErrorEnabled()) {
logger.error(message);
}
throw new RuntimeException(message);
}
}
}
private String key(String id) {
return this.config.prefix + id;
}
@Override
public Optional<Boolean> delete(List<String> idList) {
try (Pipeline pipeline = this.jedis.pipelined()) {
for (String id : idList) {
pipeline.jsonDel(key(id));
}
List<Object> responses = pipeline.syncAndReturnAll();
Optional<Object> errResponse = responses.stream().filter(Predicate.not(RESPONSE_DEL_OK)).findAny();
if (errResponse.isPresent()) {
if (logger.isErrorEnabled()) {
logger.error("Could not delete document: {}", errResponse.get());
}
return Optional.of(false);
}
return Optional.of(true);
}
}
@Override
public List<Document> similaritySearch(SearchRequest request) {
Assert.isTrue(request.getTopK() > 0, "The number of documents to returned must be greater than zero");
Assert.isTrue(request.getSimilarityThreshold() >= 0 && request.getSimilarityThreshold() <= 1,
"The similarity score is bounded between 0 and 1; least to most similar respectively.");
String filter = nativeExpressionFilter(request);
String queryString = String.format(QUERY_FORMAT, filter, request.getTopK(), this.config.embeddingFieldName,
EMBEDDING_PARAM_NAME, DISTANCE_FIELD_NAME);
List<String> returnFields = new ArrayList<>();
this.config.metadataFields.stream().map(MetadataField::name).forEach(returnFields::add);
returnFields.add(this.config.embeddingFieldName);
returnFields.add(this.config.contentFieldName);
returnFields.add(DISTANCE_FIELD_NAME);
var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery()));
Query query = new Query(queryString).addParam(EMBEDDING_PARAM_NAME, RediSearchUtil.toByteArray(embedding))
.returnFields(returnFields.toArray(new String[0]))
.setSortBy(DISTANCE_FIELD_NAME, true)
.dialect(2);
SearchResult result = this.jedis.ftSearch(this.config.indexName, query);
return result.getDocuments()
.stream()
.filter(d -> similarityScore(d) >= request.getSimilarityThreshold())
.map(this::toDocument)
.toList();
}
private Document toDocument(redis.clients.jedis.search.Document doc) {
var id = doc.getId().substring(this.config.prefix.length());
var content = doc.hasProperty(this.config.contentFieldName) ? doc.getString(this.config.contentFieldName)
: null;
Map<String, Object> metadata = this.config.metadataFields.stream()
.map(MetadataField::name)
.filter(doc::hasProperty)
.collect(Collectors.toMap(Function.identity(), doc::getString));
metadata.put(DISTANCE_FIELD_NAME, 1 - similarityScore(doc));
return new Document(id, content, metadata);
}
private float similarityScore(redis.clients.jedis.search.Document doc) {
return (2 - Float.parseFloat(doc.getString(DISTANCE_FIELD_NAME))) / 2;
}
private String nativeExpressionFilter(SearchRequest request) {
if (request.getFilterExpression() == null) {
return "*";
}
return "(" + this.filterExpressionConverter.convertExpression(request.getFilterExpression()) + ")";
}
@Override
public void afterPropertiesSet() {
if (!this.initializeSchema) {
return;
}
// If index already exists don't do anything
if (this.jedis.ftList().contains(this.config.indexName)) {
return;
}
String response = this.jedis.ftCreate(this.config.indexName,
FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(this.config.prefix), schemaFields());
if (!RESPONSE_OK.test(response)) {
String message = MessageFormat.format("Could not create index: {0}", response);
throw new RuntimeException(message);
}
}
private Iterable<SchemaField> schemaFields() {
Map<String, Object> vectorAttrs = new HashMap<>();
vectorAttrs.put("DIM", this.embeddingModel.dimensions());
vectorAttrs.put("DISTANCE_METRIC", DEFAULT_DISTANCE_METRIC);
vectorAttrs.put("TYPE", VECTOR_TYPE_FLOAT32);
List<SchemaField> fields = new ArrayList<>();
fields.add(TextField.of(jsonPath(this.config.contentFieldName)).as(this.config.contentFieldName).weight(1.0));
fields.add(VectorField.builder()
.fieldName(jsonPath(this.config.embeddingFieldName))
.algorithm(vectorAlgorithm())
.attributes(vectorAttrs)
.as(this.config.embeddingFieldName)
.build());
if (!CollectionUtils.isEmpty(this.config.metadataFields)) {
for (MetadataField field : this.config.metadataFields) {
fields.add(schemaField(field));
}
}
return fields;
}
private SchemaField schemaField(MetadataField field) {
String fieldName = jsonPath(field.name);
switch (field.fieldType) {
case NUMERIC:
return NumericField.of(fieldName).as(field.name);
case TAG:
return TagField.of(fieldName).as(field.name);
case TEXT:
return TextField.of(fieldName).as(field.name);
default:
throw new IllegalArgumentException(
MessageFormat.format("Field {0} has unsupported type {1}", field.name, field.fieldType));
}
}
private VectorAlgorithm vectorAlgorithm() {
if (config.vectorAlgorithm == Algorithm.HSNW) {
return VectorAlgorithm.HNSW;
}
return VectorAlgorithm.FLAT;
}
private String jsonPath(String field) {
return JSON_PATH_PREFIX + field;
}
private static float[] toFloatArray(List<Double> embeddingDouble) {
float[] embeddingFloat = new float[embeddingDouble.size()];
int i = 0;
for (Double d : embeddingDouble) {
embeddingFloat[i++] = d.floatValue();
}
return embeddingFloat;
}
}

View File

@@ -1,6 +1,5 @@
package cn.iocoder.yudao.framework.ai.chat;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.util.ClientOptions;
@@ -27,13 +26,13 @@ import static org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatP
*/
public class AzureOpenAIChatModelTests {
private final OpenAIClient openAiApi = (new OpenAIClientBuilder())
// TODO @芋艿:晚点在调整
private final OpenAIClientBuilder openAiApi = new OpenAIClientBuilder()
.endpoint("https://eastusprejade.openai.azure.com")
.credential(new AzureKeyCredential("xxx"))
.clientOptions((new ClientOptions()).setApplicationId("spring-ai"))
.buildClient();
.clientOptions((new ClientOptions()).setApplicationId("spring-ai"));
private final AzureOpenAiChatModel chatModel = new AzureOpenAiChatModel(openAiApi,
AzureOpenAiChatOptions.builder().withDeploymentName(DEFAULT_DEPLOYMENT_NAME).build());
AzureOpenAiChatOptions.builder().deploymentName(DEFAULT_DEPLOYMENT_NAME).build());
@Test
@Disabled

View File

@@ -8,6 +8,9 @@ import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
@@ -20,7 +23,18 @@ import java.util.List;
*/
public class DeepSeekChatModelTests {
private final DeepSeekChatModel chatModel = new DeepSeekChatModel("sk-e94db327cc7d457d99a8de8810fc6b12");
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(DeepSeekChatModel.BASE_URL)
.apiKey("sk-e52047409b144d97b791a6a46a2d") // apiKey
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model("deepseek-chat") // 模型
.temperature(0.7)
.build())
.build();
private final DeepSeekChatModel chatModel = new DeepSeekChatModel(openAiChatModel);
@Test
@Disabled

View File

@@ -0,0 +1,63 @@
package cn.iocoder.yudao.framework.ai.chat;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.api.OpenAiApi;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
/**
* 基于 {@link OpenAiChatModel} 集成 Dify 测试
*
* @author 芋道源码
*/
public class DifyChatModelTests {
private final OpenAiChatModel chatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl("http://127.0.0.1:3000")
.apiKey("app-4hy2d7fJauSbrKbzTKX1afuP") // apiKey
.build())
.build();
@Test
@Disabled
public void testCall() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = chatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
System.out.println(response.getResult().getOutput());
}
@Test
@Disabled
public void testStream() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(response -> {
// System.out.println(response);
System.out.println(response.getResult().getOutput());
}).then().block();
}
}

View File

@@ -0,0 +1,69 @@
package cn.iocoder.yudao.framework.ai.chat;
import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
/**
* {@link DouBaoChatModel} 集成测试
*
* @author 芋道源码
*/
public class DouBaoChatModelTests {
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(DouBaoChatModel.BASE_URL)
.apiKey("5c1b5747-26d2-4ebd-a4e0-dd0e8d8b4272") // apiKey
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model("doubao-1-5-lite-32k-250115") // 模型doubao
// .model("deepseek-r1-250120") // 模型deepseek
.temperature(0.7)
.build())
.build();
private final DouBaoChatModel chatModel = new DouBaoChatModel(openAiChatModel);
@Test
@Disabled
public void testCall() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = chatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
}
// TODO @芋艿:因为使用的是 v1 api导致 deepseek-r1-250120 不返回 think 过程,后续需要优化
@Test
@Disabled
public void testStream() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(System.out::println).then().block();
}
}

View File

@@ -0,0 +1,63 @@
package cn.iocoder.yudao.framework.ai.chat;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.api.OpenAiApi;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
/**
* 基于 {@link OpenAiChatModel} 集成 FastGPT 测试
*
* @author 芋道源码
*/
public class FastGPTChatModelTests {
private final OpenAiChatModel chatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl("https://cloud.fastgpt.cn/api")
.apiKey("fastgpt-aqcc61kFtF8CeaglnGAfQOCIDWwjGdJVJHv6hIlMo28otFlva2aZNK") // apiKey
.build())
.build();
@Test
@Disabled
public void testCall() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = chatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
System.out.println(response.getResult().getOutput());
}
@Test
@Disabled
public void testStream() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(response -> {
// System.out.println(response);
System.out.println(response.getResult().getOutput());
}).then().block();
}
}

View File

@@ -0,0 +1,110 @@
package cn.iocoder.yudao.framework.ai.chat;
import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
/**
* {@link HunYuanChatModel} 集成测试
*
* @author 芋道源码
*/
public class HunYuanChatModelTests {
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(HunYuanChatModel.BASE_URL)
.apiKey("sk-bcd") // apiKey
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model(HunYuanChatModel.MODEL_DEFAULT) // 模型
.temperature(0.7)
.build())
.build();
private final HunYuanChatModel chatModel = new HunYuanChatModel(openAiChatModel);
@Test
@Disabled
public void testCall() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = chatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
}
@Test
@Disabled
public void testStream() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(System.out::println).then().block();
}
private final OpenAiChatModel deepSeekOpenAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(HunYuanChatModel.DEEP_SEEK_BASE_URL)
.apiKey("sk-abc") // apiKey
.build())
.defaultOptions(OpenAiChatOptions.builder()
// .model(HunYuanChatModel.DEEP_SEEK_MODEL_DEFAULT) // 模型("deepseek-v3"
.model("deepseek-r1") // 模型("deepseek-r1"
.temperature(0.7)
.build())
.build();
private final HunYuanChatModel deepSeekChatModel = new HunYuanChatModel(deepSeekOpenAiChatModel);
@Test
@Disabled
public void testCall_deepseek() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = deepSeekChatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
}
@Test
@Disabled
public void testStream_deekseek() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = deepSeekChatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(System.out::println).then().block();
}
}

View File

@@ -23,10 +23,12 @@ import java.util.List;
*/
public class LlamaChatModelTests {
private final OllamaApi ollamaApi = new OllamaApi(
"http://127.0.0.1:11434");
private final OllamaChatModel chatModel = new OllamaChatModel(ollamaApi,
OllamaOptions.create().withModel(OllamaModel.LLAMA3.getModelName()));
private final OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(new OllamaApi("http://127.0.0.1:11434")) // Ollama 服务地址
.defaultOptions(OllamaOptions.builder()
.model(OllamaModel.LLAMA3.getName()) // 模型
.build())
.build();
@Test
@Disabled

View File

@@ -0,0 +1,62 @@
package cn.iocoder.yudao.framework.ai.chat;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.minimax.MiniMaxChatModel;
import org.springframework.ai.minimax.MiniMaxChatOptions;
import org.springframework.ai.minimax.api.MiniMaxApi;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
/**
* {@link MiniMaxChatModel} 的集成测试
*
* @author 芋道源码
*/
public class MiniMaxChatModelTests {
private final MiniMaxChatModel chatModel = new MiniMaxChatModel(
new MiniMaxApi("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLnjovmlofmlowiLCJVc2VyTmFtZSI6IueOi-aWh-aWjCIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxODk3Mjg3MjQ5NDU2ODA4MzQ2IiwiUGhvbmUiOiIxNTYwMTY5MTM5OSIsIkdyb3VwSUQiOiIxODk3Mjg3MjQ5NDQ4NDE5NzM4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjUtMDMtMTEgMTI6NTI6MDIiLCJUb2tlblR5cGUiOjEsImlzcyI6Im1pbmltYXgifQ.aAuB7gWW_oA4IYhh-CF7c9MfWWxKN49B_HK-DYjXaDwwffhiG-H1571z1WQhp9QytWG-DqgLejneeSxkiq1wQIe3FsEP2wz4BmGBct31LehbJu8ehLxg_vg75Uod1nFAHbm5mZz6JSVLNIlSo87Xr3UtSzJhAXlapEkcqlA4YOzOpKrZ8l5_OJPTORTCmHWZYgJcRS-faNiH62ZnUEHUozesTFhubJHo5GfJCw_edlnmfSUocERV1BjWvenhZ9My-aYXNktcW9WaSj9l6gayV7A0Ium_PL55T9ln1PcI8gayiVUKJGJDoqNyF1AF9_aF9NOKtTnQzwNqnZdlTYH6hw"), // 密钥
MiniMaxChatOptions.builder()
.model(MiniMaxApi.ChatModel.ABAB_6_5_G_Chat.getValue()) // 模型
.build());
@Test
@Disabled
public void testCall() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = chatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
System.out.println(response.getResult().getOutput());
}
@Test
@Disabled
public void testStream() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(response -> {
// System.out.println(response);
System.out.println(response.getResult().getOutput());
}).then().block();
}
}

View File

@@ -0,0 +1,62 @@
package cn.iocoder.yudao.framework.ai.chat;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.moonshot.MoonshotChatModel;
import org.springframework.ai.moonshot.MoonshotChatOptions;
import org.springframework.ai.moonshot.api.MoonshotApi;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
/**
* {@link org.springframework.ai.moonshot.MoonshotChatModel} 的集成测试
*
* @author 芋道源码
*/
public class MoonshotChatModelTests {
private final MoonshotChatModel chatModel = new MoonshotChatModel(
new MoonshotApi("sk-aHYYV1SARscItye5QQRRNbXij4fy65Ee7pNZlC9gsSQnUKXA"), // 密钥
MoonshotChatOptions.builder()
.model("moonshot-v1-8k") // 模型
.build());
@Test
@Disabled
public void testCall() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = chatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
System.out.println(response.getResult().getOutput());
}
@Test
@Disabled
public void testStream() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(response -> {
// System.out.println(response);
System.out.println(response.getResult().getOutput());
}).then().block();
}
}

View File

@@ -0,0 +1,65 @@
package cn.iocoder.yudao.framework.ai.chat;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
/**
* {@link OllamaChatModel} 集成测试
*
* @author 芋道源码
*/
public class OllamaChatModelTests {
private final OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(new OllamaApi("http://127.0.0.1:11434")) // Ollama 服务地址
.defaultOptions(OllamaOptions.builder()
// .model("qwen") // 模型https://ollama.com/library/qwen
.model("deepseek-r1") // 模型https://ollama.com/library/deepseek-r1
.build())
.build();
@Test
@Disabled
public void testCall() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = chatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
System.out.println(response.getResult().getOutput());
}
@Test
@Disabled
public void testStream() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(response -> {
// System.out.println(response);
System.out.println(response.getResult().getOutput());
}).then().block();
}
}

View File

@@ -22,11 +22,16 @@ import java.util.List;
*/
public class OpenAIChatModelTests {
private final OpenAiApi openAiApi = new OpenAiApi(
"https://api.holdai.top",
"sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf");
private final OpenAiChatModel chatModel = new OpenAiChatModel(openAiApi,
OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O).build());
private final OpenAiChatModel chatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl("https://api.holdai.top")
.apiKey("sk-aN6nWn3fILjrgLFT0fC4Aa60B72e4253826c77B29dC94f17") // apiKey
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model(OpenAiApi.ChatModel.GPT_4_O) // 模型
.temperature(0.7)
.build())
.build();
@Test
@Disabled

View File

@@ -0,0 +1,69 @@
package cn.iocoder.yudao.framework.ai.chat;
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowChatModel;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
/**
* {@link SiliconFlowChatModel} 集成测试
*
* @author 芋道源码
*/
public class SiliconFlowChatModelTests {
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(SiliconFlowChatModel.BASE_URL)
.apiKey("sk-epsakfenqnyzoxhmbucsxlhkdqlcbnimslqoivkshalvdozz") // apiKey
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model(SiliconFlowChatModel.MODEL_DEFAULT) // 模型
// .model("deepseek-ai/DeepSeek-R1") // 模型deepseek-ai/DeepSeek-R1可用赠费
// .model("Pro/deepseek-ai/DeepSeek-R1") // 模型Pro/deepseek-ai/DeepSeek-R1需要付费
.temperature(0.7)
.build())
.build();
private final SiliconFlowChatModel chatModel = new SiliconFlowChatModel(openAiChatModel);
@Test
@Disabled
public void testCall() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = chatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
}
@Test
@Disabled
public void testStream() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(System.out::println).then().block();
}
}

View File

@@ -1,12 +1,8 @@
package cn.iocoder.yudao.framework.ai.chat;
import cn.hutool.core.util.ReflectUtil;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.common.MessageManager;
import com.alibaba.dashscope.utils.Constants;
import org.junit.jupiter.api.BeforeEach;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
@@ -20,25 +16,20 @@ import java.util.ArrayList;
import java.util.List;
/**
* {@link TongYiChatModel} 集成测试类
* {@link DashScopeChatModel} 集成测试类
*
* @author fansili
*/
public class TongYiChatModelTests {
private final Generation generation = new Generation();
private final TongYiChatModel chatModel = new TongYiChatModel(generation,
TongYiChatOptions.builder().withModel("qwen1.5-72b-chat").build());
static {
Constants.apiKey = "sk-Zsd81gZYg7";
}
@BeforeEach
public void before() {
// 防止 TongYiChatModel 调用空指针
ReflectUtil.setFieldValue(chatModel, "msgManager", new MessageManager());
}
private final DashScopeChatModel chatModel = new DashScopeChatModel(
new DashScopeApi("sk-7d903764249848cfa912733146da12d1"),
DashScopeChatOptions.builder()
.withModel("qwen1.5-72b-chat") // 模型
// .withModel("deepseek-r1") // 模型deepseek-r1
// .withModel("deepseek-v3") // 模型deepseek-v3
// .withModel("deepseek-r1-distill-qwen-1.5b") // 模型deepseek-r1-distill-qwen-1.5b
.build());
@Test
@Disabled

View File

@@ -8,6 +8,9 @@ import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
@@ -20,9 +23,18 @@ import java.util.List;
*/
public class XingHuoChatModelTests {
private final XingHuoChatModel chatModel = new XingHuoChatModel(
"cb6415c19d6162cda07b47316fcb0416",
"Y2JiYTIxZjA3MDMxMjNjZjQzYzVmNzdh");
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl(XingHuoChatModel.BASE_URL)
.apiKey("75b161ed2aef4719b275d6e7f2a4d4cd:YWYxYWI2MTA4ODI2NGZlYTQyNjAzZTcz") // appKey:secretKey
.build())
.defaultOptions(OpenAiChatOptions.builder()
.model("generalv3.5") // 模型
.temperature(0.7)
.build())
.build();
private final XingHuoChatModel chatModel = new XingHuoChatModel(openAiChatModel);
@Test
@Disabled

View File

@@ -14,6 +14,7 @@ import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
// TODO @芋艿:百度千帆 API 提供了 V2 版本,目前 Spring AI 不兼容,可关键 <https://github.com/spring-projects/spring-ai/issues/2179> 进展
/**
* {@link QianFanChatModel} 的集成测试
*
@@ -21,11 +22,11 @@ import java.util.List;
*/
public class YiYanChatModelTests {
private final QianFanApi qianFanApi = new QianFanApi(
"qS8k8dYr2nXunagK4SSU8Xjj",
"pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
private final QianFanChatModel chatModel = new QianFanChatModel(qianFanApi,
QianFanChatOptions.builder().withModel(QianFanApi.ChatModel.ERNIE_Tiny_8K.getValue()).build()
private final QianFanChatModel chatModel = new QianFanChatModel(
new QianFanApi("qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e"), // 密钥
QianFanChatOptions.builder()
.model(QianFanApi.ChatModel.ERNIE_4_0_8K_Preview.getValue())
.build()
);
@Test

View File

@@ -22,9 +22,12 @@ import java.util.List;
*/
public class ZhiPuAiChatModelTests {
private final ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi("32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs");
private final ZhiPuAiChatModel chatModel = new ZhiPuAiChatModel(zhiPuAiApi,
ZhiPuAiChatOptions.builder().withModel(ZhiPuAiApi.ChatModel.GLM_4.getModelName()).build());
private final ZhiPuAiChatModel chatModel = new ZhiPuAiChatModel(
new ZhiPuAiApi("32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs"), // 密钥
ZhiPuAiChatOptions.builder()
.model(ZhiPuAiApi.ChatModel.GLM_4.getName()) // 模型
.build()
);
@Test
@Disabled

View File

@@ -15,8 +15,8 @@ import java.util.List;
public class MidjourneyApiTests {
private final MidjourneyApi midjourneyApi = new MidjourneyApi(
"https://api.holdai.top/mj",
"sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf",
"https://api.holdai.top/mj", // 链接
"sk-aN6nWn3fILjrgLFT0fC4Aa60B72e4253826c77B29dC94f17", // 密钥
null);
@Test

View File

@@ -8,7 +8,6 @@ import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.web.client.RestClient;
/**
* {@link OpenAiImageModel} 集成测试类
@@ -17,11 +16,10 @@ import org.springframework.web.client.RestClient;
*/
public class OpenAiImageModelTests {
private final OpenAiImageApi imageApi = new OpenAiImageApi(
"https://api.holdai.top",
"sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf",
RestClient.builder());
private final OpenAiImageModel imageModel = new OpenAiImageModel(imageApi);
private final OpenAiImageModel imageModel = new OpenAiImageModel(OpenAiImageApi.builder()
.baseUrl("https://api.holdai.top") // apiKey
.apiKey("sk-aN6nWn3fILjrgLFT0fC4Aa60B72e4253826c77B29dC94f17")
.build());
@Test
@Disabled

View File

@@ -10,14 +10,15 @@ import org.springframework.ai.qianfan.api.QianFanImageApi;
import static cn.iocoder.yudao.framework.ai.image.StabilityAiImageModelTests.viewImage;
// TODO @芋艿:百度千帆 API 提供了 V2 版本,目前 Spring AI 不兼容,可关键 <https://github.com/spring-projects/spring-ai/issues/2179> 进展
/**
* {@link QianFanImageModel} 集成测试类
*/
public class QianFanImageTests {
private final QianFanImageApi imageApi = new QianFanImageApi(
"qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
private final QianFanImageModel imageModel = new QianFanImageModel(imageApi);
private final QianFanImageModel imageModel = new QianFanImageModel(
new QianFanImageApi("qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e")); // 密钥
@Test
@Disabled
@@ -25,9 +26,9 @@ public class QianFanImageTests {
// 准备参数
// 只支持 1024x1024、768x768、768x1024、1024x768、576x1024、1024x576
QianFanImageOptions imageOptions = QianFanImageOptions.builder()
.withModel(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
.withWidth(1024).withHeight(1024)
.withN(1)
.model(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
.width(1024).height(1024)
.N(1)
.build();
ImagePrompt prompt = new ImagePrompt("good", imageOptions);

View File

@@ -22,9 +22,9 @@ import java.util.concurrent.TimeUnit;
*/
public class StabilityAiImageModelTests {
private final StabilityAiApi imageApi = new StabilityAiApi(
"sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx");
private final StabilityAiImageModel imageModel = new StabilityAiImageModel(imageApi);
private final StabilityAiImageModel imageModel = new StabilityAiImageModel(
new StabilityAiApi("sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx") // 密钥
);
@Test
@Disabled
@@ -32,7 +32,7 @@ public class StabilityAiImageModelTests {
// 准备参数
ImageOptions options = OpenAiImageOptions.builder()
.withModel("stable-diffusion-v1-6")
.withHeight(256).withWidth(256)
.withHeight(320).withWidth(320)
.build();
ImagePrompt prompt = new ImagePrompt("great wall", options);

View File

@@ -1,35 +1,30 @@
package cn.iocoder.yudao.framework.ai.image;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.utils.Constants;
import com.alibaba.cloud.ai.dashscope.api.DashScopeImageApi;
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel;
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions;
/**
* {@link com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel} 集成测试类
* {@link DashScopeImageModel} 集成测试类
*
* @author fansili
*/
public class TongYiImagesModelTest {
private final ImageSynthesis imageApi = new ImageSynthesis();
private final TongYiImagesModel imageModel = new TongYiImagesModel(imageApi);
static {
Constants.apiKey = "sk-Zsd81gZYg7";
}
private final DashScopeImageModel imageModel = new DashScopeImageModel(
new DashScopeImageApi("sk-7d903764249848cfa912733146da12d1"));
@Test
@Disabled
public void imageCallTest() {
// 准备参数
ImageOptions options = OpenAiImageOptions.builder()
.withModel(ImageSynthesis.Models.WANX_V1)
ImageOptions options = DashScopeImageOptions.builder()
.withModel("wanx-v1")
.withHeight(256).withWidth(256)
.build();
ImagePrompt prompt = new ImagePrompt("中国长城!", options);

View File

@@ -13,16 +13,16 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
*/
public class ZhiPuAiImageModelTests {
private final ZhiPuAiImageApi imageApi = new ZhiPuAiImageApi(
"78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy");
private final ZhiPuAiImageModel imageModel = new ZhiPuAiImageModel(imageApi);
private final ZhiPuAiImageModel imageModel = new ZhiPuAiImageModel(
new ZhiPuAiImageApi("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy") // 密钥
);
@Test
@Disabled
public void testCall() {
// 准备参数
ZhiPuAiImageOptions imageOptions = ZhiPuAiImageOptions.builder()
.withModel(ZhiPuAiImageApi.ImageModel.CogView_3.getValue())
.model(ZhiPuAiImageApi.ImageModel.CogView_3.getValue())
.build();
ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions);