【模块新增】AI:支持通义千问、文心一言、讯飞星火、智谱、DeepSeek 等国内外大模型能力
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
package cn.iocoder.yudao.framework.ai.config;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
|
||||
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
|
||||
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Import;
|
||||
|
||||
/**
|
||||
* 芋道 AI 自动配置
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@AutoConfiguration
|
||||
@EnableConfigurationProperties(YudaoAiProperties.class)
|
||||
@Slf4j
|
||||
@Import(TongYiAutoConfiguration.class)
|
||||
public class YudaoAiAutoConfiguration {
|
||||
|
||||
@Bean
|
||||
public AiModelFactory aiModelFactory() {
|
||||
return new AiModelFactoryImpl();
|
||||
}
|
||||
|
||||
// ========== 各种 AI Client 创建 ==========
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.deepseek.enable", havingValue = "true")
|
||||
public DeepSeekChatModel deepSeekChatModel(YudaoAiProperties yudaoAiProperties) {
|
||||
YudaoAiProperties.DeepSeekProperties properties = yudaoAiProperties.getDeepSeek();
|
||||
DeepSeekChatOptions options = DeepSeekChatOptions.builder()
|
||||
.model(properties.getModel())
|
||||
.temperature(properties.getTemperature())
|
||||
.maxTokens(properties.getMaxTokens())
|
||||
.topP(properties.getTopP())
|
||||
.build();
|
||||
return new DeepSeekChatModel(properties.getApiKey(), options);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.xinghuo.enable", havingValue = "true")
|
||||
public XingHuoChatModel xingHuoChatClient(YudaoAiProperties yudaoAiProperties) {
|
||||
YudaoAiProperties.XingHuoProperties properties = yudaoAiProperties.getXinghuo();
|
||||
XingHuoChatOptions options = XingHuoChatOptions.builder()
|
||||
.model(properties.getModel())
|
||||
.temperature(properties.getTemperature())
|
||||
.maxTokens(properties.getMaxTokens())
|
||||
.topK(properties.getTopK())
|
||||
.build();
|
||||
return new XingHuoChatModel(properties.getAppKey(), properties.getSecretKey(), options);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
|
||||
public MidjourneyApi midjourneyApi(YudaoAiProperties yudaoAiProperties) {
|
||||
YudaoAiProperties.MidjourneyProperties config = yudaoAiProperties.getMidjourney();
|
||||
return new MidjourneyApi(config.getBaseUrl(), config.getApiKey(), config.getNotifyUrl());
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.suno.enable", havingValue = "true")
|
||||
public SunoApi sunoApi(YudaoAiProperties yudaoAiProperties) {
|
||||
return new SunoApi(yudaoAiProperties.getSuno().getBaseUrl());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package cn.iocoder.yudao.framework.ai.config;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
|
||||
/**
|
||||
* 芋道 AI 配置类
|
||||
*
|
||||
* @author fansili
|
||||
* @since 1.0
|
||||
*/
|
||||
@ConfigurationProperties(prefix = "yudao.ai")
|
||||
@Data
|
||||
public class YudaoAiProperties {
|
||||
|
||||
/**
|
||||
* DeepSeek
|
||||
*/
|
||||
private DeepSeekProperties deepSeek;
|
||||
|
||||
/**
|
||||
* 讯飞星火
|
||||
*/
|
||||
private XingHuoProperties xinghuo;
|
||||
|
||||
/**
|
||||
* Midjourney 绘图
|
||||
*/
|
||||
private MidjourneyProperties midjourney;
|
||||
|
||||
/**
|
||||
* Suno 音乐
|
||||
*/
|
||||
private SunoProperties suno;
|
||||
|
||||
@Data
|
||||
public static class XingHuoProperties {
|
||||
|
||||
private String enable;
|
||||
private String appId;
|
||||
private String appKey;
|
||||
private String secretKey;
|
||||
|
||||
private String model;
|
||||
private Float temperature;
|
||||
private Integer maxTokens;
|
||||
private Integer topK;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class DeepSeekProperties {
|
||||
|
||||
private String enable;
|
||||
private String apiKey;
|
||||
|
||||
private String model;
|
||||
private Float temperature;
|
||||
private Integer maxTokens;
|
||||
private Float topP;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class MidjourneyProperties {
|
||||
|
||||
private String enable;
|
||||
private String baseUrl;
|
||||
|
||||
private String apiKey;
|
||||
private String notifyUrl;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class SunoProperties {
|
||||
|
||||
private boolean enable = false;
|
||||
|
||||
private String baseUrl;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.enums;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
/**
|
||||
* AI 模型平台
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum AiPlatformEnum {
|
||||
|
||||
// ========== 国内平台 ==========
|
||||
|
||||
TONG_YI("TongYi", "通义千问"), // 阿里
|
||||
YI_YAN("YiYan", "文心一言"), // 百度
|
||||
DEEP_SEEK("DeepSeek", "DeepSeek"), // DeepSeek
|
||||
ZHI_PU("ZhiPu", "智谱"), // 智谱 AI
|
||||
XING_HUO("XingHuo", "星火"), // 讯飞
|
||||
|
||||
// ========== 国外平台 ==========
|
||||
|
||||
OPENAI("OpenAI", "OpenAI"),
|
||||
OLLAMA("Ollama", "Ollama"),
|
||||
|
||||
STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI
|
||||
MIDJOURNEY("Midjourney", "Midjourney"), // Midjourney
|
||||
SUNO("Suno", "Suno"), // Suno AI
|
||||
|
||||
;
|
||||
|
||||
/**
|
||||
* 平台
|
||||
*/
|
||||
private final String platform;
|
||||
/**
|
||||
* 平台名
|
||||
*/
|
||||
private final String name;
|
||||
|
||||
public static AiPlatformEnum validatePlatform(String platform) {
|
||||
for (AiPlatformEnum platformEnum : AiPlatformEnum.values()) {
|
||||
if (platformEnum.getPlatform().equals(platform)) {
|
||||
return platformEnum;
|
||||
}
|
||||
}
|
||||
throw new IllegalArgumentException("非法平台: " + platform);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.factory;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.image.ImageModel;
|
||||
|
||||
/**
|
||||
* AI Model 模型工厂的接口类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public interface AiModelFactory {
|
||||
|
||||
/**
|
||||
* 基于指定配置,获得 ChatModel 对象
|
||||
*
|
||||
* 如果不存在,则进行创建
|
||||
*
|
||||
* @param platform 平台
|
||||
* @param apiKey API KEY
|
||||
* @param url API URL
|
||||
* @return ChatModel 对象
|
||||
*/
|
||||
ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url);
|
||||
|
||||
/**
|
||||
* 基于默认配置,获得 ChatModel 对象
|
||||
*
|
||||
* 默认配置,指的是在 application.yaml 配置文件中的 spring.ai 相关的配置
|
||||
*
|
||||
* @param platform 平台
|
||||
* @return ChatModel 对象
|
||||
*/
|
||||
ChatModel getDefaultChatModel(AiPlatformEnum platform);
|
||||
|
||||
/**
|
||||
* 基于默认配置,获得 ImageModel 对象
|
||||
*
|
||||
* 默认配置,指的是在 application.yaml 配置文件中的 spring.ai 相关的配置
|
||||
*
|
||||
* @param platform 平台
|
||||
* @return ImageModel 对象
|
||||
*/
|
||||
ImageModel getDefaultImageModel(AiPlatformEnum platform);
|
||||
|
||||
/**
|
||||
* 基于指定配置,获得 ImageModel 对象
|
||||
*
|
||||
* 如果不存在,则进行创建
|
||||
*
|
||||
* @param platform 平台
|
||||
* @param apiKey API KEY
|
||||
* @param url API URL
|
||||
* @return ImageModel 对象
|
||||
*/
|
||||
ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url);
|
||||
|
||||
/**
|
||||
* 基于指定配置,获得 MidjourneyApi 对象
|
||||
*
|
||||
* 如果不存在,则进行创建
|
||||
*
|
||||
* @param apiKey API KEY
|
||||
* @param url API URL
|
||||
* @return MidjourneyApi 对象
|
||||
*/
|
||||
MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url);
|
||||
|
||||
/**
|
||||
* 基于指定配置,获得 SunoApi 对象
|
||||
*
|
||||
* 如果不存在,则进行创建
|
||||
*
|
||||
* @param apiKey API KEY
|
||||
* @param url API URL
|
||||
* @return SunoApi 对象
|
||||
*/
|
||||
SunoApi getOrCreateSunoApi(String apiKey, String url);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,294 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.factory;
|
||||
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.lang.Singleton;
|
||||
import cn.hutool.core.lang.func.Func0;
|
||||
import cn.hutool.core.util.ArrayUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.extra.spring.SpringUtil;
|
||||
import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
|
||||
import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
|
||||
import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
|
||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
|
||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
|
||||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
|
||||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
|
||||
import com.alibaba.dashscope.aigc.generation.Generation;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
|
||||
import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
|
||||
import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
|
||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
|
||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
|
||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.image.ImageModel;
|
||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||
import org.springframework.ai.ollama.OllamaChatModel;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiImageModel;
|
||||
import org.springframework.ai.openai.api.ApiUtils;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.openai.api.OpenAiImageApi;
|
||||
import org.springframework.ai.qianfan.QianFanChatModel;
|
||||
import org.springframework.ai.qianfan.QianFanImageModel;
|
||||
import org.springframework.ai.qianfan.api.QianFanApi;
|
||||
import org.springframework.ai.qianfan.api.QianFanImageApi;
|
||||
import org.springframework.ai.stabilityai.StabilityAiImageModel;
|
||||
import org.springframework.ai.stabilityai.api.StabilityAiApi;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
|
||||
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
|
||||
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.web.client.ResponseErrorHandler;
|
||||
import org.springframework.web.client.RestClient;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* AI Model 模型工厂的实现类
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class AiModelFactoryImpl implements AiModelFactory {
|
||||
|
||||
@Override
|
||||
public ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url) {
|
||||
String cacheKey = buildClientCacheKey(ChatModel.class, platform, apiKey, url);
|
||||
return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
|
||||
//noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return buildTongYiChatModel(apiKey);
|
||||
case YI_YAN:
|
||||
return buildYiYanChatModel(apiKey);
|
||||
case DEEP_SEEK:
|
||||
return buildDeepSeekChatModel(apiKey);
|
||||
case ZHI_PU:
|
||||
return buildZhiPuChatModel(apiKey, url);
|
||||
case XING_HUO:
|
||||
return buildXingHuoChatModel(apiKey);
|
||||
case OPENAI:
|
||||
return buildOpenAiChatModel(apiKey, url);
|
||||
case OLLAMA:
|
||||
return buildOllamaChatModel(url);
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
|
||||
//noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return SpringUtil.getBean(TongYiChatModel.class);
|
||||
case YI_YAN:
|
||||
return SpringUtil.getBean(QianFanChatModel.class);
|
||||
case DEEP_SEEK:
|
||||
return SpringUtil.getBean(DeepSeekChatModel.class);
|
||||
case ZHI_PU:
|
||||
return SpringUtil.getBean(ZhiPuAiChatModel.class);
|
||||
case XING_HUO:
|
||||
return SpringUtil.getBean(XingHuoChatModel.class);
|
||||
case OPENAI:
|
||||
return SpringUtil.getBean(OpenAiChatModel.class);
|
||||
case OLLAMA:
|
||||
return SpringUtil.getBean(OllamaChatModel.class);
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
|
||||
//noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return SpringUtil.getBean(TongYiImagesModel.class);
|
||||
case YI_YAN:
|
||||
return SpringUtil.getBean(QianFanImageModel.class);
|
||||
case ZHI_PU:
|
||||
return SpringUtil.getBean(ZhiPuAiImageModel.class);
|
||||
case OPENAI:
|
||||
return SpringUtil.getBean(OpenAiImageModel.class);
|
||||
case STABLE_DIFFUSION:
|
||||
return SpringUtil.getBean(StabilityAiImageModel.class);
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
|
||||
//noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return buildTongYiImagesModel(apiKey);
|
||||
case YI_YAN:
|
||||
return buildQianFanImageModel(apiKey);
|
||||
case ZHI_PU:
|
||||
return buildZhiPuAiImageModel(apiKey, url);
|
||||
case OPENAI:
|
||||
return buildOpenAiImageModel(apiKey, url);
|
||||
case STABLE_DIFFUSION:
|
||||
return buildStabilityAiImageModel(apiKey, url);
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url) {
|
||||
String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey, url);
|
||||
return Singleton.get(cacheKey, (Func0<MidjourneyApi>) () -> {
|
||||
YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class).getMidjourney();
|
||||
return new MidjourneyApi(url, apiKey, properties.getNotifyUrl());
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public SunoApi getOrCreateSunoApi(String apiKey, String url) {
|
||||
String cacheKey = buildClientCacheKey(SunoApi.class, AiPlatformEnum.SUNO.getPlatform(), apiKey, url);
|
||||
return Singleton.get(cacheKey, (Func0<SunoApi>) () -> new SunoApi(url));
|
||||
}
|
||||
|
||||
private static String buildClientCacheKey(Class<?> clazz, Object... params) {
|
||||
if (ArrayUtil.isEmpty(params)) {
|
||||
return clazz.getName();
|
||||
}
|
||||
return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
|
||||
}
|
||||
|
||||
// ========== 各种创建 spring-ai 客户端的方法 ==========
|
||||
|
||||
/**
|
||||
* 可参考 {@link TongYiAutoConfiguration#tongYiChatClient(Generation, TongYiChatProperties, TongYiConnectionProperties)}
|
||||
*/
|
||||
private static TongYiChatModel buildTongYiChatModel(String key) {
|
||||
com.alibaba.dashscope.aigc.generation.Generation generation = SpringUtil.getBean(Generation.class);
|
||||
TongYiChatProperties chatOptions = SpringUtil.getBean(TongYiChatProperties.class);
|
||||
// TODO @芋艿:貌似 apiKey 是全局唯一的???得测试下
|
||||
// TODO @芋艿:貌似阿里云不是增量返回的
|
||||
// 该 issue 进行跟进中 https://github.com/alibaba/spring-cloud-alibaba/issues/3790
|
||||
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
|
||||
connectionProperties.setApiKey(key);
|
||||
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
|
||||
}
|
||||
|
||||
private static TongYiImagesModel buildTongYiImagesModel(String key) {
|
||||
ImageSynthesis imageSynthesis = SpringUtil.getBean(ImageSynthesis.class);
|
||||
TongYiImagesProperties imagesOptions = SpringUtil.getBean(TongYiImagesProperties.class);
|
||||
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
|
||||
connectionProperties.setApiKey(key);
|
||||
return new TongYiAutoConfiguration().tongYiImagesClient(imageSynthesis, imagesOptions, connectionProperties);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
|
||||
*/
|
||||
private static QianFanChatModel buildYiYanChatModel(String key) {
|
||||
List<String> keys = StrUtil.split(key, '|');
|
||||
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
|
||||
String appKey = keys.get(0);
|
||||
String secretKey = keys.get(1);
|
||||
QianFanApi qianFanApi = new QianFanApi(appKey, secretKey);
|
||||
return new QianFanChatModel(qianFanApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link QianFanAutoConfiguration#qianFanImageModel(QianFanConnectionProperties, QianFanImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
|
||||
*/
|
||||
private QianFanImageModel buildQianFanImageModel(String key) {
|
||||
List<String> keys = StrUtil.split(key, '|');
|
||||
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
|
||||
String appKey = keys.get(0);
|
||||
String secretKey = keys.get(1);
|
||||
QianFanImageApi qianFanApi = new QianFanImageApi(appKey, secretKey);
|
||||
return new QianFanImageModel(qianFanApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link YudaoAiAutoConfiguration#deepSeekChatModel(YudaoAiProperties)}
|
||||
*/
|
||||
private static DeepSeekChatModel buildDeepSeekChatModel(String apiKey) {
|
||||
return new DeepSeekChatModel(apiKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(
|
||||
* ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
|
||||
*/
|
||||
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
|
||||
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
|
||||
ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi(url, apiKey);
|
||||
return new ZhiPuAiChatModel(zhiPuAiApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(
|
||||
* ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
|
||||
*/
|
||||
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
|
||||
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
|
||||
ZhiPuAiImageApi zhiPuAiApi = new ZhiPuAiImageApi(url, apiKey, RestClient.builder());
|
||||
return new ZhiPuAiImageModel(zhiPuAiApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)}
|
||||
*/
|
||||
private static XingHuoChatModel buildXingHuoChatModel(String key) {
|
||||
List<String> keys = StrUtil.split(key, '|');
|
||||
Assert.equals(keys.size(), 3, "XingHuoChatClient 的密钥需要 (appid|appKey|secretKey) 格式");
|
||||
String appKey = keys.get(1);
|
||||
String secretKey = keys.get(2);
|
||||
return new XingHuoChatModel(appKey, secretKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link OpenAiAutoConfiguration}
|
||||
*/
|
||||
private static OpenAiChatModel buildOpenAiChatModel(String openAiToken, String url) {
|
||||
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
|
||||
OpenAiApi openAiApi = new OpenAiApi(url, openAiToken);
|
||||
return new OpenAiChatModel(openAiApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link OpenAiAutoConfiguration}
|
||||
*/
|
||||
private OpenAiImageModel buildOpenAiImageModel(String openAiToken, String url) {
|
||||
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
|
||||
OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
|
||||
return new OpenAiImageModel(openAiApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link OllamaAutoConfiguration}
|
||||
*/
|
||||
private static OllamaChatModel buildOllamaChatModel(String url) {
|
||||
OllamaApi ollamaApi = new OllamaApi(url);
|
||||
return new OllamaChatModel(ollamaApi);
|
||||
}
|
||||
|
||||
private StabilityAiImageModel buildStabilityAiImageModel(String apiKey, String url) {
|
||||
url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
|
||||
StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);
|
||||
return new StabilityAiImageModel(stabilityAiApi);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.model.deepseek;
|
||||
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions.MODEL_DEFAULT;
|
||||
|
||||
/**
|
||||
* DeepSeek {@link ChatModel} 实现类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Slf4j
|
||||
public class DeepSeekChatModel implements ChatModel {
|
||||
|
||||
private static final String BASE_URL = "https://api.deepseek.com";
|
||||
|
||||
private final DeepSeekChatOptions defaultOptions;
|
||||
private final RetryTemplate retryTemplate;
|
||||
|
||||
/**
|
||||
* DeepSeek 兼容 OpenAI 的 HTTP 接口,所以复用它的实现,简化接入成本
|
||||
*
|
||||
* 不过要注意,DeepSeek 没有完全兼容,所以不能使用 {@link org.springframework.ai.openai.OpenAiChatModel} 调用,但是实现会参考它
|
||||
*/
|
||||
private final OpenAiApi openAiApi;
|
||||
|
||||
/**
 * Creates a model with the default options: model {@code MODEL_DEFAULT}, temperature 0.7.
 *
 * @param apiKey DeepSeek API key
 */
public DeepSeekChatModel(String apiKey) {
    this(apiKey,
            DeepSeekChatOptions.builder()
                    .model(MODEL_DEFAULT)
                    .temperature(0.7F)
                    .build());
}

/**
 * Creates a model with the given default options and {@link RetryUtils#DEFAULT_RETRY_TEMPLATE}.
 *
 * @param apiKey  DeepSeek API key
 * @param options default chat options
 */
public DeepSeekChatModel(String apiKey, DeepSeekChatOptions options) {
    this(apiKey, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

/**
 * Creates a model with the given default options and retry template.
 *
 * @param apiKey        DeepSeek API key; must not be empty
 * @param options       default chat options; must not be null
 * @param retryTemplate retry template wrapped around API calls; must not be null
 */
public DeepSeekChatModel(String apiKey, DeepSeekChatOptions options, RetryTemplate retryTemplate) {
    // Validate all arguments before any state is assigned
    Assert.notEmpty(apiKey, "apiKey 不能为空");
    Assert.notNull(options, "options 不能为空");
    Assert.notNull(retryTemplate, "retryTemplate 不能为空");
    this.defaultOptions = options;
    this.retryTemplate = retryTemplate;
    this.openAiApi = new OpenAiApi(BASE_URL, apiKey);
}
|
||||
|
||||
@Override
|
||||
public ChatResponse call(Prompt prompt) {
|
||||
OpenAiApi.ChatCompletionRequest request = createRequest(prompt, false);
|
||||
return this.retryTemplate.execute(ctx -> {
|
||||
// 1.1 发起调用
|
||||
ResponseEntity<OpenAiApi.ChatCompletion> completionEntity = openAiApi.chatCompletionEntity(request);
|
||||
// 1.2 校验结果
|
||||
OpenAiApi.ChatCompletion chatCompletion = completionEntity.getBody();
|
||||
if (chatCompletion == null) {
|
||||
log.warn("No chat completion returned for prompt: {}", prompt);
|
||||
return new ChatResponse(List.of());
|
||||
}
|
||||
List<OpenAiApi.ChatCompletion.Choice> choices = chatCompletion.choices();
|
||||
if (choices == null) {
|
||||
log.warn("No choices returned for prompt: {}", prompt);
|
||||
return new ChatResponse(List.of());
|
||||
}
|
||||
|
||||
// 2. 转换 ChatResponse 返回
|
||||
List<Generation> generations = choices.stream().map(choice -> {
|
||||
Generation generation = new Generation(choice.message().content(), toMap(chatCompletion.id(), choice));
|
||||
if (choice.finishReason() != null) {
|
||||
generation.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
|
||||
}
|
||||
return generation;
|
||||
}).toList();
|
||||
return new ChatResponse(generations,
|
||||
OpenAiChatResponseMetadata.from(completionEntity.getBody()));
|
||||
});
|
||||
}
|
||||
|
||||
private Map<String, Object> toMap(String id, OpenAiApi.ChatCompletion.Choice choice) {
|
||||
Map<String, Object> map = new HashMap<>();
|
||||
OpenAiApi.ChatCompletionMessage message = choice.message();
|
||||
if (message.role() != null) {
|
||||
map.put("role", message.role().name());
|
||||
}
|
||||
if (choice.finishReason() != null) {
|
||||
map.put("finishReason", choice.finishReason().name());
|
||||
}
|
||||
map.put("id", id);
|
||||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||
OpenAiApi.ChatCompletionRequest request = createRequest(prompt, true);
|
||||
return this.retryTemplate.execute(ctx -> {
|
||||
// 1. 发起调用
|
||||
Flux<OpenAiApi.ChatCompletionChunk> response = this.openAiApi.chatCompletionStream(request);
|
||||
return response.map(chatCompletion -> {
|
||||
String id = chatCompletion.id();
|
||||
// 2. 转换 ChatResponse 返回
|
||||
List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
|
||||
String finish = (choice.finishReason() != null ? choice.finishReason().name() : "");
|
||||
String role = (choice.delta().role() != null ? choice.delta().role().name() : "");
|
||||
if (choice.finishReason() == OpenAiApi.ChatCompletionFinishReason.STOP) {
|
||||
// 兜底处理 DeepSeek 返回 STOP 时,role 为空的情况
|
||||
role = OpenAiApi.ChatCompletionMessage.Role.ASSISTANT.name();
|
||||
}
|
||||
Generation generation = new Generation(choice.delta().content(),
|
||||
Map.of("id", id, "role", role, "finishReason", finish));
|
||||
if (choice.finishReason() != null) {
|
||||
generation = generation.withGenerationMetadata(
|
||||
ChatGenerationMetadata.from(choice.finishReason().name(), null));
|
||||
}
|
||||
return generation;
|
||||
}).toList();
|
||||
return new ChatResponse(generations);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
|
||||
// 1. 构建 ChatCompletionMessage 对象
|
||||
List<OpenAiApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(m ->
|
||||
new OpenAiApi.ChatCompletionMessage(m.getContent(), OpenAiApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))).toList();
|
||||
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, stream);
|
||||
|
||||
// 2.1 补充 prompt 内置的 options
|
||||
if (prompt.getOptions() != null) {
|
||||
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
|
||||
OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
|
||||
ChatOptions.class, OpenAiChatOptions.class);
|
||||
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, OpenAiApi.ChatCompletionRequest.class);
|
||||
} else {
|
||||
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
|
||||
+ prompt.getOptions().getClass().getSimpleName());
|
||||
}
|
||||
}
|
||||
// 2.2 补充默认 options
|
||||
if (this.defaultOptions != null) {
|
||||
request = ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.ChatCompletionRequest.class);
|
||||
}
|
||||
return request;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return DeepSeekChatOptions.fromOptions(defaultOptions);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.model.deepseek;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
|
||||
/**
|
||||
* DeepSeek {@link ChatOptions} 实现类
|
||||
*
|
||||
* 参考文档:<a href="https://platform.deepseek.com/api-docs/zh-cn/">快速开始</a>
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Builder
|
||||
public class DeepSeekChatOptions implements ChatOptions {
|
||||
|
||||
public static final String MODEL_DEFAULT = "deepseek-chat";
|
||||
|
||||
/**
|
||||
* 模型
|
||||
*/
|
||||
private String model;
|
||||
/**
|
||||
* 温度
|
||||
*/
|
||||
private Float temperature;
|
||||
/**
|
||||
* 最大 Token
|
||||
*/
|
||||
private Integer maxTokens;
|
||||
/**
|
||||
* topP
|
||||
*/
|
||||
private Float topP;
|
||||
|
||||
@Override
|
||||
public Integer getTopK() {
|
||||
return null;
|
||||
}
|
||||
|
||||
public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) {
|
||||
return DeepSeekChatOptions.builder()
|
||||
.model(fromOptions.getModel())
|
||||
.temperature(fromOptions.getTemperature())
|
||||
.maxTokens(fromOptions.getMaxTokens())
|
||||
.topP(fromOptions.getTopP())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,348 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.model.midjourney.api;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.openai.api.ApiUtils;
|
||||
import org.springframework.http.HttpRequest;
|
||||
import org.springframework.http.HttpStatusCode;
|
||||
import org.springframework.web.reactive.function.client.ClientResponse;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/**
|
||||
* Midjourney API
|
||||
*
|
||||
* @author fansili
|
||||
* @since 1.0
|
||||
*/
|
||||
@Slf4j
|
||||
public class MidjourneyApi {
|
||||
|
||||
private final Predicate<HttpStatusCode> STATUS_PREDICATE = status -> !status.is2xxSuccessful();
|
||||
|
||||
private final Function<Object, Function<ClientResponse, Mono<? extends Throwable>>> EXCEPTION_FUNCTION =
|
||||
reqParam -> response -> response.bodyToMono(String.class).handle((responseBody, sink) -> {
|
||||
HttpRequest request = response.request();
|
||||
log.error("[midjourney-api] 调用失败!请求方式:[{}],请求地址:[{}],请求参数:[{}],响应数据: [{}]",
|
||||
request.getMethod(), request.getURI(), reqParam, responseBody);
|
||||
sink.error(new IllegalStateException("[midjourney-api] 调用失败!"));
|
||||
});
|
||||
|
||||
private final WebClient webClient;
|
||||
|
||||
/**
|
||||
* 回调地址
|
||||
*/
|
||||
private final String notifyUrl;
|
||||
|
||||
public MidjourneyApi(String baseUrl, String apiKey, String notifyUrl) {
|
||||
this.webClient = WebClient.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.defaultHeaders(ApiUtils.getJsonContentHeaders(apiKey))
|
||||
.build();
|
||||
this.notifyUrl = notifyUrl;
|
||||
}
|
||||
|
||||
/**
|
||||
* imagine - 根据提示词提交绘画任务
|
||||
*
|
||||
* @param request 请求
|
||||
* @return 提交结果
|
||||
*/
|
||||
public SubmitResponse imagine(ImagineRequest request) {
|
||||
if (StrUtil.isEmpty(request.getNotifyHook())) {
|
||||
request.setNotifyHook(notifyUrl);
|
||||
}
|
||||
String response = post("/submit/imagine", request);
|
||||
return JsonUtils.parseObject(response, SubmitResponse.class);
|
||||
}
|
||||
|
||||
/**
|
||||
* action - 放大、缩小、U1、U2...
|
||||
*
|
||||
* @param request 请求
|
||||
* @return 提交结果
|
||||
*/
|
||||
public SubmitResponse action(ActionRequest request) {
|
||||
if (StrUtil.isEmpty(request.getNotifyHook())) {
|
||||
request.setNotifyHook(notifyUrl);
|
||||
}
|
||||
String response = post("/submit/action", request);
|
||||
return JsonUtils.parseObject(response, SubmitResponse.class);
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量查询 task 任务
|
||||
*
|
||||
* @param ids 任务编号数组
|
||||
* @return task 任务
|
||||
*/
|
||||
public List<Notify> getTaskList(Collection<String> ids) {
|
||||
String res = post("/task/list-by-condition", ImmutableMap.of("ids", ids));
|
||||
return JsonUtils.parseArray(res, Notify.class);
|
||||
}
|
||||
|
||||
private String post(String uri, Object body) {
|
||||
return webClient.post()
|
||||
.uri(uri)
|
||||
.body(Mono.just(JsonUtils.toJsonString(body)), String.class)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(body))
|
||||
.bodyToMono(String.class)
|
||||
.block();
|
||||
}
|
||||
|
||||
// ========== record 结构 ==========
|
||||
|
||||
/**
|
||||
* Imagine 请求(生成图片)
|
||||
*/
|
||||
@Data
|
||||
public static final class ImagineRequest {
|
||||
|
||||
/**
|
||||
* 垫图(参考图) base64 数组
|
||||
*/
|
||||
private List<String> base64Array;
|
||||
/**
|
||||
* 提示词
|
||||
*/
|
||||
private String prompt;
|
||||
/**
|
||||
* 通知地址
|
||||
*/
|
||||
private String notifyHook;
|
||||
/**
|
||||
* 自定义参数
|
||||
*/
|
||||
private String state;
|
||||
|
||||
public ImagineRequest(List<String> base64Array, String prompt, String notifyHook, String state) {
|
||||
this.base64Array = base64Array;
|
||||
this.prompt = prompt;
|
||||
this.notifyHook = notifyHook;
|
||||
this.state = state;
|
||||
}
|
||||
|
||||
public static String buildState(Integer width, Integer height, String version, String model) {
|
||||
StringBuilder params = new StringBuilder();
|
||||
// --ar 来设置尺寸
|
||||
params.append(String.format(" --ar %s:%s ", width, height));
|
||||
// --niji 模型
|
||||
if (ModelEnum.NIJI.getModel().equals(model)) {
|
||||
params.append(String.format(" --niji %s ", version));
|
||||
} else {
|
||||
params.append(String.format(" --v %s ", version));
|
||||
}
|
||||
return params.toString();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Action 请求
|
||||
*/
|
||||
@Data
|
||||
public static final class ActionRequest {
|
||||
|
||||
private String customId;
|
||||
private String taskId;
|
||||
private String notifyHook;
|
||||
|
||||
public ActionRequest(String taskId, String customId, String notifyHook) {
|
||||
this.customId = customId;
|
||||
this.taskId = taskId;
|
||||
this.notifyHook = notifyHook;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Submit 统一返回
|
||||
*
|
||||
* @param code 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)
|
||||
* @param description 描述
|
||||
* @param properties 扩展字段
|
||||
* @param result 任务ID
|
||||
*/
|
||||
public record SubmitResponse(String code,
|
||||
String description,
|
||||
Map<String, Object> properties,
|
||||
String result) {
|
||||
}
|
||||
|
||||
/**
|
||||
* 通知 request
|
||||
*
|
||||
* @param id job id
|
||||
* @param action 任务类型 {@link TaskActionEnum}
|
||||
* @param status 任务状态 {@link TaskStatusEnum}
|
||||
* @param prompt 提示词
|
||||
* @param promptEn 提示词-英文
|
||||
* @param description 任务描述
|
||||
* @param state 自定义参数
|
||||
* @param submitTime 提交时间
|
||||
* @param startTime 开始执行时间
|
||||
* @param finishTime 结束时间
|
||||
* @param imageUrl 图片url
|
||||
* @param progress 任务进度
|
||||
* @param failReason 失败原因
|
||||
* @param buttons 任务完成后的可执行按钮
|
||||
*/
|
||||
public record Notify(String id,
|
||||
String action,
|
||||
String status,
|
||||
|
||||
String prompt,
|
||||
String promptEn,
|
||||
|
||||
String description,
|
||||
String state,
|
||||
|
||||
Long submitTime,
|
||||
Long startTime,
|
||||
Long finishTime,
|
||||
|
||||
String imageUrl,
|
||||
String progress,
|
||||
String failReason,
|
||||
List<Button> buttons) {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* button
|
||||
*
|
||||
* @param customId MJ::JOB::upsample::1::85a4b4c1-8835-46c5-a15c-aea34fad1862 动作标识
|
||||
* @param emoji 图标 emoji
|
||||
* @param label Make Variations 文本
|
||||
* @param type 类型,系统内部使用
|
||||
* @param style 样式: 2(Primary)、3(Green)
|
||||
*/
|
||||
public record Button(String customId,
|
||||
String emoji,
|
||||
String label,
|
||||
String type,
|
||||
String style) {
|
||||
}
|
||||
|
||||
// ============ enums ============
|
||||
|
||||
/**
|
||||
* 模型枚举
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
@Getter
|
||||
public enum ModelEnum {
|
||||
|
||||
MIDJOURNEY("midjourney", "midjourney"),
|
||||
NIJI("niji", "niji"),
|
||||
;
|
||||
|
||||
private final String model;
|
||||
private final String name;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 提交返回的状态码的枚举
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum SubmitCodeEnum {
|
||||
|
||||
SUBMIT_SUCCESS("1", "提交成功"),
|
||||
ALREADY_EXISTS("21", "已存在"),
|
||||
QUEUING("22", "排队中"),
|
||||
;
|
||||
|
||||
public static final List<String> SUCCESS_CODES = Lists.newArrayList(
|
||||
SUBMIT_SUCCESS.code,
|
||||
ALREADY_EXISTS.code,
|
||||
QUEUING.code
|
||||
);
|
||||
|
||||
private final String code;
|
||||
private final String name;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Action 枚举
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum TaskActionEnum {
|
||||
|
||||
/**
|
||||
* 生成图片
|
||||
*/
|
||||
IMAGINE,
|
||||
/**
|
||||
* 选中放大
|
||||
*/
|
||||
UPSCALE,
|
||||
/**
|
||||
* 选中其中的一张图,生成四张相似的
|
||||
*/
|
||||
VARIATION,
|
||||
/**
|
||||
* 重新执行
|
||||
*/
|
||||
REROLL,
|
||||
/**
|
||||
* 图转 prompt
|
||||
*/
|
||||
DESCRIBE,
|
||||
/**
|
||||
* 多图混合
|
||||
*/
|
||||
BLEND
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 任务状态枚举
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum TaskStatusEnum {
|
||||
|
||||
/**
|
||||
* 未启动
|
||||
*/
|
||||
NOT_START(0),
|
||||
/**
|
||||
* 已提交
|
||||
*/
|
||||
SUBMITTED(1),
|
||||
/**
|
||||
* 执行中
|
||||
*/
|
||||
IN_PROGRESS(3),
|
||||
/**
|
||||
* 失败
|
||||
*/
|
||||
FAILURE(4),
|
||||
/**
|
||||
* 成功
|
||||
*/
|
||||
SUCCESS(4);
|
||||
|
||||
private final int order;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.model.suno.api;
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.text.StrPool;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.HttpRequest;
|
||||
import org.springframework.http.HttpStatusCode;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.web.reactive.function.client.ClientResponse;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/**
|
||||
* Suno API
|
||||
* <p>
|
||||
* 对接 Suno Proxy:<a href="https://github.com/gcui-art/suno-api">suno-api</a>
|
||||
*
|
||||
* @author xiaoxin
|
||||
*/
|
||||
@Slf4j
|
||||
public class SunoApi {
|
||||
|
||||
private final WebClient webClient;
|
||||
|
||||
private final Predicate<HttpStatusCode> STATUS_PREDICATE = status -> !status.is2xxSuccessful();
|
||||
|
||||
private final Function<Object, Function<ClientResponse, Mono<? extends Throwable>>> EXCEPTION_FUNCTION =
|
||||
reqParam -> response -> response.bodyToMono(String.class).handle((responseBody, sink) -> {
|
||||
HttpRequest request = response.request();
|
||||
log.error("[suno-api] 调用失败!请求方式:[{}],请求地址:[{}],请求参数:[{}],响应数据: [{}]",
|
||||
request.getMethod(), request.getURI(), reqParam, responseBody);
|
||||
sink.error(new IllegalStateException("[suno-api] 调用失败!"));
|
||||
});
|
||||
|
||||
public SunoApi(String baseUrl) {
|
||||
this.webClient = WebClient.builder()
|
||||
.baseUrl(baseUrl)
|
||||
.defaultHeaders((headers) -> headers.setContentType(MediaType.APPLICATION_JSON))
|
||||
.build();
|
||||
}
|
||||
|
||||
public List<MusicData> generate(MusicGenerateRequest request) {
|
||||
return this.webClient.post()
|
||||
.uri("/api/generate")
|
||||
.body(Mono.just(request), MusicGenerateRequest.class)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(request))
|
||||
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() {
|
||||
})
|
||||
.block();
|
||||
}
|
||||
|
||||
public List<MusicData> customGenerate(MusicGenerateRequest request) {
|
||||
return this.webClient.post()
|
||||
.uri("/api/custom_generate")
|
||||
.body(Mono.just(request), MusicGenerateRequest.class)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(request))
|
||||
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() {
|
||||
})
|
||||
.block();
|
||||
}
|
||||
|
||||
public LyricsData generateLyrics(String prompt) {
|
||||
return this.webClient.post()
|
||||
.uri("/api/generate_lyrics")
|
||||
.body(Mono.just(new MusicGenerateRequest(prompt)), MusicGenerateRequest.class)
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(prompt))
|
||||
.bodyToMono(LyricsData.class)
|
||||
.block();
|
||||
}
|
||||
|
||||
public List<MusicData> getMusicList(List<String> ids) {
|
||||
return this.webClient.get()
|
||||
.uri(uriBuilder -> uriBuilder
|
||||
.path("/api/get")
|
||||
.queryParam("ids", CollUtil.join(ids, StrPool.COMMA))
|
||||
.build())
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(ids))
|
||||
.bodyToMono(new ParameterizedTypeReference<List<MusicData>>() {
|
||||
})
|
||||
.block();
|
||||
}
|
||||
|
||||
public LimitUsageData getLimitUsage() {
|
||||
return this.webClient.get()
|
||||
.uri("/api/get_limit")
|
||||
.retrieve()
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(null))
|
||||
.bodyToMono(LimitUsageData.class)
|
||||
.block();
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据提示生成音频
|
||||
*
|
||||
* @param prompt 用于生成音乐音频的提示
|
||||
* @param tags 音乐风格
|
||||
* @param title 音乐名称
|
||||
* @param model 模型
|
||||
* @param waitAudio false 表示后台模式,仅返回音频任务信息,需要调用 get API 获取详细的音频信息。
|
||||
* true 表示同步模式,API 最多等待 100s,音频生成完毕后直接返回音频链接等信息,建议在 GPT 等 agent 中使用。
|
||||
* @param makeInstrumental 指示音乐音频是否为定制,如果为 true,则从歌词生成,否则从提示生成
|
||||
*/
|
||||
@JsonInclude(value = JsonInclude.Include.NON_NULL)
|
||||
public record MusicGenerateRequest(
|
||||
String prompt,
|
||||
String tags,
|
||||
String title,
|
||||
String model,
|
||||
@JsonProperty("wait_audio") boolean waitAudio,
|
||||
@JsonProperty("make_instrumental") boolean makeInstrumental
|
||||
) {
|
||||
|
||||
public MusicGenerateRequest(String prompt) {
|
||||
this(prompt, null, null, null, false, false);
|
||||
}
|
||||
|
||||
public MusicGenerateRequest(String prompt, String model, boolean makeInstrumental) {
|
||||
this(prompt, null, null, model, false, makeInstrumental);
|
||||
}
|
||||
|
||||
public MusicGenerateRequest(String prompt, String model, String tags, String title) {
|
||||
this(prompt, tags, title, model, false, false);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Suno API 响应的音频数据
|
||||
*
|
||||
* @param id 音乐数据的 ID
|
||||
* @param title 音乐音频的标题
|
||||
* @param imageUrl 音乐音频的图片 URL
|
||||
* @param lyric 音乐音频的歌词
|
||||
* @param audioUrl 音乐音频的 URL
|
||||
* @param videoUrl 音乐视频的 URL
|
||||
* @param createdAt 音乐音频的创建时间
|
||||
* @param modelName 模型名称
|
||||
* @param status submitted、queued、streaming、complete
|
||||
* @param gptDescriptionPrompt 描述词
|
||||
* @param prompt 生成音乐音频的提示
|
||||
* @param type 操作类型
|
||||
* @param tags 音乐类型标签
|
||||
* @param duration 音乐时长
|
||||
*/
|
||||
public record MusicData(
|
||||
String id,
|
||||
String title,
|
||||
@JsonProperty("image_url") String imageUrl,
|
||||
String lyric,
|
||||
@JsonProperty("audio_url") String audioUrl,
|
||||
@JsonProperty("video_url") String videoUrl,
|
||||
@JsonProperty("created_at") String createdAt,
|
||||
@JsonProperty("model_name") String modelName,
|
||||
String status,
|
||||
@JsonProperty("gpt_description_prompt") String gptDescriptionPrompt,
|
||||
@JsonProperty("error_message") String errorMessage,
|
||||
String prompt,
|
||||
String type,
|
||||
String tags,
|
||||
Double duration
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Suno API 响应的歌词数据。
|
||||
*
|
||||
* @param text 歌词
|
||||
* @param title 标题
|
||||
* @param status 状态
|
||||
*/
|
||||
public record LyricsData(
|
||||
String text,
|
||||
String title,
|
||||
String status
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Suno API 响应的限额数据,目前每日免费 50
|
||||
*/
|
||||
public record LimitUsageData(
|
||||
@JsonProperty("credits_left") Long creditsLeft,
|
||||
String period,
|
||||
@JsonProperty("monthly_limit") Long monthlyLimit,
|
||||
@JsonProperty("monthly_usage") Long monthlyUsage
|
||||
) {
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.model.xinghuo;
|
||||
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions.MODEL_DEFAULT;
|
||||
|
||||
/**
|
||||
* 讯飞星火 {@link ChatModel} 实现类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Slf4j
|
||||
public class XingHuoChatModel implements ChatModel {
|
||||
|
||||
private static final String BASE_URL = "https://spark-api-open.xf-yun.com";
|
||||
|
||||
private final XingHuoChatOptions defaultOptions;
|
||||
private final RetryTemplate retryTemplate;
|
||||
|
||||
/**
|
||||
* 星火兼容 OpenAI 的 HTTP 接口,所以复用它的实现,简化接入成本
|
||||
*
|
||||
* 不过要注意,星火没有完全兼容,所以不能使用 {@link org.springframework.ai.openai.OpenAiChatModel} 调用,但是实现会参考它
|
||||
*/
|
||||
private final OpenAiApi openAiApi;
|
||||
|
||||
public XingHuoChatModel(String apiKey, String secretKey) {
|
||||
this(apiKey, secretKey,
|
||||
XingHuoChatOptions.builder().model(MODEL_DEFAULT).temperature(0.7F).build());
|
||||
}
|
||||
|
||||
public XingHuoChatModel(String apiKey, String secretKey, XingHuoChatOptions options) {
|
||||
this(apiKey, secretKey, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
public XingHuoChatModel(String apiKey, String secretKey, XingHuoChatOptions options, RetryTemplate retryTemplate) {
|
||||
Assert.notEmpty(apiKey, "apiKey 不能为空");
|
||||
Assert.notEmpty(secretKey, "secretKey 不能为空");
|
||||
Assert.notNull(options, "options 不能为空");
|
||||
Assert.notNull(retryTemplate, "retryTemplate 不能为空");
|
||||
this.openAiApi = new OpenAiApi(BASE_URL, apiKey + ":" + secretKey);
|
||||
this.defaultOptions = options;
|
||||
this.retryTemplate = retryTemplate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse call(Prompt prompt) {
|
||||
OpenAiApi.ChatCompletionRequest request = createRequest(prompt, false);
|
||||
return this.retryTemplate.execute(ctx -> {
|
||||
// 1.1 发起调用
|
||||
ResponseEntity<OpenAiApi.ChatCompletion> completionEntity = openAiApi.chatCompletionEntity(request);
|
||||
// 1.2 校验结果
|
||||
OpenAiApi.ChatCompletion chatCompletion = completionEntity.getBody();
|
||||
if (chatCompletion == null) {
|
||||
log.warn("No chat completion returned for prompt: {}", prompt);
|
||||
return new ChatResponse(List.of());
|
||||
}
|
||||
List<OpenAiApi.ChatCompletion.Choice> choices = chatCompletion.choices();
|
||||
if (choices == null) {
|
||||
log.warn("No choices returned for prompt: {}", prompt);
|
||||
return new ChatResponse(List.of());
|
||||
}
|
||||
|
||||
// 2. 转换 ChatResponse 返回
|
||||
List<Generation> generations = choices.stream().map(choice -> {
|
||||
Generation generation = new Generation(choice.message().content(), toMap(chatCompletion.id(), choice));
|
||||
if (choice.finishReason() != null) {
|
||||
generation.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
|
||||
}
|
||||
return generation;
|
||||
}).toList();
|
||||
return new ChatResponse(generations,
|
||||
OpenAiChatResponseMetadata.from(completionEntity.getBody()));
|
||||
});
|
||||
}
|
||||
|
||||
private Map<String, Object> toMap(String id, OpenAiApi.ChatCompletion.Choice choice) {
|
||||
Map<String, Object> map = new HashMap<>();
|
||||
OpenAiApi.ChatCompletionMessage message = choice.message();
|
||||
if (message.role() != null) {
|
||||
map.put("role", message.role().name());
|
||||
}
|
||||
if (choice.finishReason() != null) {
|
||||
map.put("finishReason", choice.finishReason().name());
|
||||
}
|
||||
map.put("id", id);
|
||||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||
OpenAiApi.ChatCompletionRequest request = createRequest(prompt, true);
|
||||
return this.retryTemplate.execute(ctx -> {
|
||||
// 1. 发起调用
|
||||
Flux<OpenAiApi.ChatCompletionChunk> response = this.openAiApi.chatCompletionStream(request);
|
||||
return response.map(chatCompletion -> {
|
||||
String id = chatCompletion.id();
|
||||
// 2. 转换 ChatResponse 返回
|
||||
List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
|
||||
String finish = (choice.finishReason() != null ? choice.finishReason().name() : "");
|
||||
Generation generation = new Generation(choice.delta().content(),
|
||||
Map.of("id", id, "role", choice.delta().role().name(), "finishReason", finish));
|
||||
if (choice.finishReason() != null) {
|
||||
generation = generation.withGenerationMetadata(
|
||||
ChatGenerationMetadata.from(choice.finishReason().name(), null));
|
||||
}
|
||||
return generation;
|
||||
}).toList();
|
||||
return new ChatResponse(generations);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
|
||||
// 1. 构建 ChatCompletionMessage 对象
|
||||
List<OpenAiApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(m ->
|
||||
new OpenAiApi.ChatCompletionMessage(m.getContent(), OpenAiApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))).toList();
|
||||
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, stream);
|
||||
|
||||
// 2.1 补充 prompt 内置的 options
|
||||
if (prompt.getOptions() != null) {
|
||||
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
|
||||
OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
|
||||
ChatOptions.class, OpenAiChatOptions.class);
|
||||
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, OpenAiApi.ChatCompletionRequest.class);
|
||||
} else {
|
||||
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
|
||||
+ prompt.getOptions().getClass().getSimpleName());
|
||||
}
|
||||
}
|
||||
// 2.2 补充默认 options
|
||||
if (this.defaultOptions != null) {
|
||||
request = ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.ChatCompletionRequest.class);
|
||||
}
|
||||
return request;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return XingHuoChatOptions.fromOptions(defaultOptions);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.model.xinghuo;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
|
||||
/**
|
||||
* 讯飞星火 {@link ChatOptions} 实现类
|
||||
*
|
||||
* 参考文档:<a href="https://www.xfyun.cn/doc/spark/HTTP%E8%B0%83%E7%94%A8%E6%96%87%E6%A1%A3.html">HTTP 调用</a>
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Builder
|
||||
public class XingHuoChatOptions implements ChatOptions {
|
||||
|
||||
public static final String MODEL_DEFAULT = "generalv3.5";
|
||||
|
||||
/**
|
||||
* 模型
|
||||
*/
|
||||
private String model;
|
||||
/**
|
||||
* 温度
|
||||
*/
|
||||
private Float temperature;
|
||||
/**
|
||||
* 最大 Token
|
||||
*/
|
||||
private Integer maxTokens;
|
||||
/**
|
||||
* K 个候选
|
||||
*/
|
||||
private Integer topK;
|
||||
|
||||
@Override
|
||||
public Float getTopP() {
|
||||
return null;
|
||||
}
|
||||
|
||||
public static XingHuoChatOptions fromOptions(XingHuoChatOptions fromOptions) {
|
||||
return XingHuoChatOptions.builder()
|
||||
.model(fromOptions.getModel())
|
||||
.temperature(fromOptions.getTemperature())
|
||||
.maxTokens(fromOptions.getMaxTokens())
|
||||
.topK(fromOptions.getTopK())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.util;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatOptions;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions;
|
||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
||||
import org.springframework.ai.chat.messages.*;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.qianfan.QianFanChatOptions;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
||||
|
||||
/**
|
||||
* Spring AI 工具类
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class AiUtils {
|
||||
|
||||
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
||||
Float temperatureF = temperature != null ? temperature.floatValue() : null;
|
||||
//noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
|
||||
case YI_YAN:
|
||||
return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
case DEEP_SEEK:
|
||||
return DeepSeekChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
|
||||
case ZHI_PU:
|
||||
return ZhiPuAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
case XING_HUO:
|
||||
return XingHuoChatOptions.builder().model(model).temperature(temperatureF).maxTokens(maxTokens).build();
|
||||
case OPENAI:
|
||||
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
case OLLAMA:
|
||||
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
}
|
||||
|
||||
public static Message buildMessage(String type, String content) {
|
||||
if (MessageType.USER.getValue().equals(type)) {
|
||||
return new UserMessage(content);
|
||||
}
|
||||
if (MessageType.ASSISTANT.getValue().equals(type)) {
|
||||
return new AssistantMessage(content);
|
||||
}
|
||||
if (MessageType.SYSTEM.getValue().equals(type)) {
|
||||
return new SystemMessage(content);
|
||||
}
|
||||
if (MessageType.FUNCTION.getValue().equals(type)) {
|
||||
return new FunctionMessage(content);
|
||||
}
|
||||
throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
/**
|
||||
* AI 大模型组件,基于 Spring AI 拓展
|
||||
*
|
||||
* core.model 包路径:
|
||||
* 1. xinghuo 包:【讯飞】星火,自己实现
|
||||
* 2. deepseek 包:【深度求索】DeepSeek,自己实现
|
||||
* 3. midjourney 包:Midjourney API,对接 https://github.com/novicezk/midjourney-proxy 实现
|
||||
* 4. suno 包:Suno API,对接 https://github.com/gcui-art/suno-api 实现
|
||||
*/
|
||||
package cn.iocoder.yudao.framework.ai;
|
||||
@@ -0,0 +1,253 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.audio.speech.TongYiAudioSpeechModel;
|
||||
import com.alibaba.cloud.ai.tongyi.audio.speech.TongYiAudioSpeechProperties;
|
||||
import com.alibaba.cloud.ai.tongyi.audio.transcription.TongYiAudioTranscriptionModel;
|
||||
import com.alibaba.cloud.ai.tongyi.audio.transcription.TongYiAudioTranscriptionProperties;
|
||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
|
||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
|
||||
import com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants;
|
||||
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiException;
|
||||
import com.alibaba.cloud.ai.tongyi.embedding.TongYiTextEmbeddingModel;
|
||||
import com.alibaba.cloud.ai.tongyi.embedding.TongYiTextEmbeddingProperties;
|
||||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
|
||||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
|
||||
import com.alibaba.dashscope.aigc.generation.Generation;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||
import com.alibaba.dashscope.audio.asr.transcription.Transcription;
|
||||
import com.alibaba.dashscope.audio.tts.SpeechSynthesizer;
|
||||
import com.alibaba.dashscope.common.MessageManager;
|
||||
import com.alibaba.dashscope.embeddings.TextEmbedding;
|
||||
import com.alibaba.dashscope.exception.NoApiKeyException;
|
||||
import com.alibaba.dashscope.utils.ApiKey;
|
||||
import com.alibaba.dashscope.utils.Constants;
|
||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Scope;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
@AutoConfiguration
|
||||
@ConditionalOnClass({
|
||||
MessageManager.class,
|
||||
TongYiChatModel.class,
|
||||
TongYiImagesModel.class,
|
||||
TongYiAudioSpeechModel.class,
|
||||
TongYiTextEmbeddingModel.class,
|
||||
TongYiAudioTranscriptionModel.class
|
||||
})
|
||||
@EnableConfigurationProperties({
|
||||
TongYiChatProperties.class,
|
||||
TongYiImagesProperties.class,
|
||||
TongYiAudioSpeechProperties.class,
|
||||
TongYiConnectionProperties.class,
|
||||
TongYiTextEmbeddingProperties.class,
|
||||
TongYiAudioTranscriptionProperties.class
|
||||
})
|
||||
public class TongYiAutoConfiguration {
|
||||
|
||||
@Bean
|
||||
@Scope("prototype")
|
||||
@ConditionalOnMissingBean
|
||||
public Generation generation() {
|
||||
|
||||
return new Generation();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@Scope("prototype")
|
||||
@ConditionalOnMissingBean
|
||||
public MessageManager msgManager() {
|
||||
|
||||
return new MessageManager(10);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@Scope("prototype")
|
||||
@ConditionalOnMissingBean
|
||||
public ImageSynthesis imageSynthesis() {
|
||||
|
||||
return new ImageSynthesis();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@Scope("prototype")
|
||||
@ConditionalOnMissingBean
|
||||
public SpeechSynthesizer speechSynthesizer() {
|
||||
|
||||
return new SpeechSynthesizer();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
public Transcription transcription() {
|
||||
|
||||
return new Transcription();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
public TextEmbedding textEmbedding() {
|
||||
|
||||
return new TextEmbedding();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) {
|
||||
|
||||
FunctionCallbackContext manager = new FunctionCallbackContext();
|
||||
manager.setApplicationContext(context);
|
||||
|
||||
return manager;
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(
|
||||
prefix = TongYiChatProperties.CONFIG_PREFIX,
|
||||
name = "enabled",
|
||||
havingValue = "true",
|
||||
matchIfMissing = true
|
||||
)
|
||||
public TongYiChatModel tongYiChatClient(Generation generation,
|
||||
TongYiChatProperties chatOptions,
|
||||
TongYiConnectionProperties connectionProperties
|
||||
) {
|
||||
|
||||
settingApiKey(connectionProperties);
|
||||
|
||||
return new TongYiChatModel(generation, chatOptions.getOptions());
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(
|
||||
prefix = TongYiImagesProperties.CONFIG_PREFIX,
|
||||
name = "enabled",
|
||||
havingValue = "true",
|
||||
matchIfMissing = true
|
||||
)
|
||||
public TongYiImagesModel tongYiImagesClient(
|
||||
ImageSynthesis imageSynthesis,
|
||||
TongYiImagesProperties imagesOptions,
|
||||
TongYiConnectionProperties connectionProperties
|
||||
) {
|
||||
|
||||
settingApiKey(connectionProperties);
|
||||
|
||||
return new TongYiImagesModel(imageSynthesis, imagesOptions.getOptions());
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(
|
||||
prefix = TongYiAudioSpeechProperties.CONFIG_PREFIX,
|
||||
name = "enabled",
|
||||
havingValue = "true",
|
||||
matchIfMissing = true
|
||||
)
|
||||
public TongYiAudioSpeechModel tongYiAudioSpeechClient(
|
||||
SpeechSynthesizer speechSynthesizer,
|
||||
TongYiAudioSpeechProperties speechProperties,
|
||||
TongYiConnectionProperties connectionProperties
|
||||
) {
|
||||
|
||||
settingApiKey(connectionProperties);
|
||||
|
||||
return new TongYiAudioSpeechModel(speechSynthesizer, speechProperties.getOptions());
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(
|
||||
prefix = TongYiAudioTranscriptionProperties.CONFIG_PREFIX,
|
||||
name = "enabled",
|
||||
havingValue = "true",
|
||||
matchIfMissing = true
|
||||
)
|
||||
public TongYiAudioTranscriptionModel tongYiAudioTranscriptionClient(
|
||||
Transcription transcription,
|
||||
TongYiAudioTranscriptionProperties transcriptionProperties,
|
||||
TongYiConnectionProperties connectionProperties) {
|
||||
|
||||
settingApiKey(connectionProperties);
|
||||
|
||||
return new TongYiAudioTranscriptionModel(
|
||||
transcriptionProperties.getOptions(),
|
||||
transcription
|
||||
);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(
|
||||
prefix = TongYiTextEmbeddingProperties.CONFIG_PREFIX,
|
||||
name = "enabled",
|
||||
havingValue = "true",
|
||||
matchIfMissing = true
|
||||
)
|
||||
public TongYiTextEmbeddingModel tongYiTextEmbeddingClient(
|
||||
TextEmbedding textEmbedding,
|
||||
TongYiConnectionProperties connectionProperties
|
||||
) {
|
||||
|
||||
settingApiKey(connectionProperties);
|
||||
return new TongYiTextEmbeddingModel(textEmbedding);
|
||||
}
|
||||
|
||||
/**
|
||||
* Setting the API key.
|
||||
* @param connectionProperties {@link TongYiConnectionProperties}
|
||||
*/
|
||||
private void settingApiKey(TongYiConnectionProperties connectionProperties) {
|
||||
|
||||
String apiKey;
|
||||
|
||||
try {
|
||||
// It is recommended to set the key by defining the api-key in an environment variable.
|
||||
var envKey = System.getenv(TongYiConstants.SCA_AI_TONGYI_API_KEY);
|
||||
if (Objects.nonNull(envKey)) {
|
||||
Constants.apiKey = envKey;
|
||||
return;
|
||||
}
|
||||
if (Objects.nonNull(connectionProperties.getApiKey())) {
|
||||
apiKey = connectionProperties.getApiKey();
|
||||
}
|
||||
else {
|
||||
apiKey = ApiKey.getApiKey(null);
|
||||
}
|
||||
|
||||
Constants.apiKey = apiKey;
|
||||
}
|
||||
catch (NoApiKeyException e) {
|
||||
|
||||
throw new TongYiException(e.getMessage());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi;
|
||||
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
|
||||
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
|
||||
|
||||
/**
|
||||
* Spring Cloud Alibaba AI TongYi LLM connection properties.
|
||||
*
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
@ConfigurationProperties(TongYiConnectionProperties.CONFIG_PREFIX)
|
||||
public class TongYiConnectionProperties {
|
||||
|
||||
/**
|
||||
* Spring Cloud Alibaba AI connection configuration Prefix.
|
||||
*/
|
||||
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "tongyi";
|
||||
|
||||
/**
|
||||
* TongYi LLM API key.
|
||||
*/
|
||||
private String apiKey;
|
||||
|
||||
public String getApiKey() {
|
||||
return apiKey;
|
||||
}
|
||||
|
||||
public void setApiKey(String apiKey) {
|
||||
this.apiKey = apiKey;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio;
|
||||
|
||||
/**
 * Model identifiers for speech synthesis.
 *
 * <p>Any model from the DashScope model list can be used:
 * https://help.aliyun.com/zh/dashscope/model-list?spm=a2c4g.11186623.0.i5
 *
 * @author yuluo
 * @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
 * @since 2023.0.1.0
 */
public final class AudioSpeechModels {

    /**
     * "Male Voice of the Tongue" voice, Chinese and English.
     * Default sample rate: 48 kHz (per the DashScope model list).
     */
    public static final String SAMBERT_ZHICHU_V1 = "sambert-zhichu-v1";

    /** Constants holder; never instantiated. */
    private AudioSpeechModels() {
    }

}
|
||||
@@ -0,0 +1,43 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio;
|
||||
|
||||
/**
 * Model identifiers for audio transcription (Paraformer family).
 *
 * @author xYLiu
 * @author yuluo
 * @since 2023.0.1.0
 */
public final class AudioTranscriptionModels {

    /**
     * Chinese/English speech recognition for audio or video sampled at 16 kHz or above.
     */
    public static final String Paraformer_V1 = "paraformer-v1";

    /**
     * Chinese speech recognition for 8 kHz telephone audio.
     */
    public static final String Paraformer_8K_V1 = "paraformer-8k-v1";

    /**
     * Multilingual speech recognition for audio or video sampled at 16 kHz or above.
     */
    public static final String Paraformer_MTL_V1 = "paraformer-mtl-v1";

    /** Constants holder; never instantiated. */
    private AudioTranscriptionModels() {
    }

}
|
||||
@@ -0,0 +1,228 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.speech;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.audio.AudioSpeechModels;
|
||||
import com.alibaba.cloud.ai.tongyi.audio.speech.api.*;
|
||||
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioSpeechResponseMetadata;
|
||||
import com.alibaba.dashscope.audio.tts.SpeechSynthesisParam;
|
||||
import com.alibaba.dashscope.audio.tts.SpeechSynthesisResult;
|
||||
import com.alibaba.dashscope.audio.tts.SpeechSynthesizer;
|
||||
import com.alibaba.dashscope.common.ResultCallback;
|
||||
import io.reactivex.Flowable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.util.Assert;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
/**
|
||||
* TongYiAudioSpeechClient is a client for TongYi audio speech service for Spring Cloud Alibaba AI.
|
||||
*
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiAudioSpeechModel implements SpeechModel, SpeechStreamModel {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(getClass());
|
||||
|
||||
/**
|
||||
* Default speed rate.
|
||||
*/
|
||||
private static final float SPEED_RATE = 1.0f;
|
||||
|
||||
/**
|
||||
* TongYi models api.
|
||||
*/
|
||||
private final SpeechSynthesizer speechSynthesizer;
|
||||
|
||||
/**
|
||||
* TongYi models options.
|
||||
*/
|
||||
private final TongYiAudioSpeechOptions defaultOptions;
|
||||
|
||||
/**
|
||||
* TongYiAudioSpeechClient constructor.
|
||||
* @param speechSynthesizer the speech synthesizer
|
||||
*/
|
||||
public TongYiAudioSpeechModel(SpeechSynthesizer speechSynthesizer) {
|
||||
|
||||
this(speechSynthesizer, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* TongYiAudioSpeechClient constructor.
|
||||
* @param speechSynthesizer the speech synthesizer
|
||||
* @param tongYiAudioOptions the tongYi audio options
|
||||
*/
|
||||
public TongYiAudioSpeechModel(SpeechSynthesizer speechSynthesizer, TongYiAudioSpeechOptions tongYiAudioOptions) {
|
||||
|
||||
Assert.notNull(speechSynthesizer, "speechSynthesizer must not be null");
|
||||
Assert.notNull(tongYiAudioOptions, "tongYiAudioOptions must not be null");
|
||||
|
||||
this.speechSynthesizer = speechSynthesizer;
|
||||
this.defaultOptions = tongYiAudioOptions;
|
||||
}
|
||||
|
||||
/**
|
||||
* Call the TongYi audio speech service.
|
||||
* @param text the text message to be converted to audio.
|
||||
* @return the audio byte buffer.
|
||||
*/
|
||||
@Override
|
||||
public ByteBuffer call(String text) {
|
||||
|
||||
var speechRequest = new SpeechPrompt(text);
|
||||
|
||||
return call(speechRequest).getResult().getOutput();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Call the TongYi audio speech service.
|
||||
* @param prompt the speech prompt.
|
||||
* @return the speech response.
|
||||
*/
|
||||
@Override
|
||||
public SpeechResponse call(SpeechPrompt prompt) {
|
||||
|
||||
var SCASpeechParam = merge(prompt.getOptions());
|
||||
var speechSynthesisParams = toSpeechSynthesisParams(SCASpeechParam);
|
||||
speechSynthesisParams.setText(prompt.getInstructions().getText());
|
||||
logger.info(speechSynthesisParams.toString());
|
||||
|
||||
var res = speechSynthesizer.call(speechSynthesisParams);
|
||||
|
||||
return convert(res, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Call the TongYi audio speech service.
|
||||
* @param prompt the speech prompt.
|
||||
* @param callback the result callback.
|
||||
* {@link SpeechSynthesizer#call(SpeechSynthesisParam, ResultCallback)}
|
||||
*/
|
||||
public void call(SpeechPrompt prompt, ResultCallback<SpeechSynthesisResult> callback) {
|
||||
|
||||
var SCASpeechParam = merge(prompt.getOptions());
|
||||
var speechSynthesisParams = toSpeechSynthesisParams(SCASpeechParam);
|
||||
speechSynthesisParams.setText(prompt.getInstructions().getText());
|
||||
|
||||
speechSynthesizer.call(speechSynthesisParams, callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream the TongYi audio speech service.
|
||||
* @param prompt the speech prompt.
|
||||
* @return the speech response.
|
||||
* {@link SpeechSynthesizer#streamCall(SpeechSynthesisParam)}
|
||||
*/
|
||||
@Override
|
||||
public Flux<SpeechResponse> stream(SpeechPrompt prompt) {
|
||||
|
||||
var SCASpeechParam = merge(prompt.getOptions());
|
||||
|
||||
Flowable<SpeechSynthesisResult> resultFlowable = speechSynthesizer
|
||||
.streamCall(toSpeechSynthesisParams(SCASpeechParam));
|
||||
|
||||
return Flux.from(resultFlowable)
|
||||
.flatMap(
|
||||
res -> Flux.just(res.getAudioFrame())
|
||||
.map(audio -> {
|
||||
var speech = new Speech(audio);
|
||||
var respMetadata = TongYiAudioSpeechResponseMetadata.from(res);
|
||||
return new SpeechResponse(speech, respMetadata);
|
||||
})
|
||||
).publishOn(Schedulers.parallel());
|
||||
}
|
||||
|
||||
public TongYiAudioSpeechOptions merge(TongYiAudioSpeechOptions target) {
|
||||
|
||||
var mergeBuilder = TongYiAudioSpeechOptions.builder();
|
||||
|
||||
mergeBuilder.withModel(defaultOptions.getModel() != null ? defaultOptions.getModel() : target.getModel());
|
||||
mergeBuilder.withPitch(defaultOptions.getPitch() != null ? defaultOptions.getPitch() : target.getPitch());
|
||||
mergeBuilder.withRate(defaultOptions.getRate() != null ? defaultOptions.getRate() : target.getRate());
|
||||
mergeBuilder.withFormat(defaultOptions.getFormat() != null ? defaultOptions.getFormat() : target.getFormat());
|
||||
mergeBuilder.withSampleRate(defaultOptions.getSampleRate() != null ? defaultOptions.getSampleRate() : target.getSampleRate());
|
||||
mergeBuilder.withTextType(defaultOptions.getTextType() != null ? defaultOptions.getTextType() : target.getTextType());
|
||||
mergeBuilder.withVolume(defaultOptions.getVolume() != null ? defaultOptions.getVolume() : target.getVolume());
|
||||
mergeBuilder.withEnablePhonemeTimestamp(defaultOptions.isEnablePhonemeTimestamp() != null ? defaultOptions.isEnablePhonemeTimestamp() : target.isEnablePhonemeTimestamp());
|
||||
mergeBuilder.withEnableWordTimestamp(defaultOptions.isEnableWordTimestamp() != null ? defaultOptions.isEnableWordTimestamp() : target.isEnableWordTimestamp());
|
||||
|
||||
return mergeBuilder.build();
|
||||
}
|
||||
|
||||
public SpeechSynthesisParam toSpeechSynthesisParams(TongYiAudioSpeechOptions source) {
|
||||
|
||||
var mergeBuilder = SpeechSynthesisParam.builder();
|
||||
|
||||
mergeBuilder.model(source.getModel() != null ? source.getModel() : AudioSpeechModels.SAMBERT_ZHICHU_V1);
|
||||
mergeBuilder.text(source.getText() != null ? source.getText() : "");
|
||||
|
||||
if (source.getFormat() != null) {
|
||||
mergeBuilder.format(source.getFormat());
|
||||
}
|
||||
if (source.getRate() != null) {
|
||||
mergeBuilder.rate(source.getRate());
|
||||
}
|
||||
if (source.getPitch() != null) {
|
||||
mergeBuilder.pitch(source.getPitch());
|
||||
}
|
||||
if (source.getTextType() != null) {
|
||||
mergeBuilder.textType(source.getTextType());
|
||||
}
|
||||
if (source.getSampleRate() != null) {
|
||||
mergeBuilder.sampleRate(source.getSampleRate());
|
||||
}
|
||||
if (source.isEnablePhonemeTimestamp() != null) {
|
||||
mergeBuilder.enablePhonemeTimestamp(source.isEnablePhonemeTimestamp());
|
||||
}
|
||||
if (source.isEnableWordTimestamp() != null) {
|
||||
mergeBuilder.enableWordTimestamp(source.isEnableWordTimestamp());
|
||||
}
|
||||
if (source.getVolume() != null) {
|
||||
mergeBuilder.volume(source.getVolume());
|
||||
}
|
||||
|
||||
return mergeBuilder.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert the TongYi audio speech service result to the speech response.
|
||||
* @param result the audio byte buffer.
|
||||
* @param synthesisResult the synthesis result.
|
||||
* @return the speech response.
|
||||
*/
|
||||
private SpeechResponse convert(ByteBuffer result, SpeechSynthesisResult synthesisResult) {
|
||||
|
||||
if (synthesisResult == null) {
|
||||
|
||||
return new SpeechResponse(new Speech(result));
|
||||
}
|
||||
|
||||
var responseMetadata = TongYiAudioSpeechResponseMetadata.from(synthesisResult);
|
||||
var speech = new Speech(synthesisResult.getAudioFrame());
|
||||
|
||||
return new SpeechResponse(speech, responseMetadata);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,261 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.speech;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.audio.AudioSpeechModels;
|
||||
import com.alibaba.dashscope.audio.tts.SpeechSynthesisAudioFormat;
|
||||
import com.alibaba.dashscope.audio.tts.SpeechSynthesisTextType;
|
||||
import org.springframework.ai.model.ModelOptions;
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiAudioSpeechOptions implements ModelOptions {
|
||||
|
||||
|
||||
/**
|
||||
* Audio Speech models.
|
||||
*/
|
||||
private String model = AudioSpeechModels.SAMBERT_ZHICHU_V1;
|
||||
|
||||
/**
|
||||
* Text content.
|
||||
*/
|
||||
private String text;
|
||||
|
||||
/**
|
||||
* Input text type.
|
||||
*/
|
||||
private SpeechSynthesisTextType textType = SpeechSynthesisTextType.PLAIN_TEXT;
|
||||
|
||||
/**
|
||||
* synthesis audio format.
|
||||
*/
|
||||
private SpeechSynthesisAudioFormat format = SpeechSynthesisAudioFormat.WAV;
|
||||
|
||||
/**
|
||||
* synthesis audio sample rate.
|
||||
*/
|
||||
private Integer sampleRate = 16000;
|
||||
|
||||
/**
|
||||
* synthesis audio volume.
|
||||
*/
|
||||
private Integer volume = 50;
|
||||
|
||||
/**
|
||||
* synthesis audio speed.
|
||||
*/
|
||||
private Float rate = 1.0f;
|
||||
|
||||
/**
|
||||
* synthesis audio pitch.
|
||||
*/
|
||||
private Float pitch = 1.0f;
|
||||
|
||||
/**
|
||||
* enable word level timestamp.
|
||||
*/
|
||||
private Boolean enableWordTimestamp = false;
|
||||
|
||||
/**
|
||||
* enable phoneme level timestamp.
|
||||
*/
|
||||
private Boolean enablePhonemeTimestamp = false;
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public String getModel() {
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
public void setModel(String model) {
|
||||
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
public String getText() {
|
||||
|
||||
return text;
|
||||
}
|
||||
|
||||
public void setText(String text) {
|
||||
|
||||
this.text = text;
|
||||
}
|
||||
|
||||
public SpeechSynthesisTextType getTextType() {
|
||||
|
||||
return textType;
|
||||
}
|
||||
|
||||
public void setTextType(SpeechSynthesisTextType textType) {
|
||||
|
||||
this.textType = textType;
|
||||
}
|
||||
|
||||
public SpeechSynthesisAudioFormat getFormat() {
|
||||
|
||||
return format;
|
||||
}
|
||||
|
||||
public void setFormat(SpeechSynthesisAudioFormat format) {
|
||||
|
||||
this.format = format;
|
||||
}
|
||||
|
||||
public Integer getSampleRate() {
|
||||
|
||||
return sampleRate;
|
||||
}
|
||||
|
||||
public void setSampleRate(Integer sampleRate) {
|
||||
|
||||
this.sampleRate = sampleRate;
|
||||
}
|
||||
|
||||
public Integer getVolume() {
|
||||
|
||||
return volume;
|
||||
}
|
||||
|
||||
public void setVolume(Integer volume) {
|
||||
|
||||
this.volume = volume;
|
||||
}
|
||||
|
||||
public Float getRate() {
|
||||
|
||||
return rate;
|
||||
}
|
||||
|
||||
public void setRate(Float rate) {
|
||||
|
||||
this.rate = rate;
|
||||
}
|
||||
|
||||
public Float getPitch() {
|
||||
|
||||
return pitch;
|
||||
}
|
||||
|
||||
public void setPitch(Float pitch) {
|
||||
|
||||
this.pitch = pitch;
|
||||
}
|
||||
|
||||
public Boolean isEnableWordTimestamp() {
|
||||
|
||||
return enableWordTimestamp;
|
||||
}
|
||||
|
||||
public void setEnableWordTimestamp(Boolean enableWordTimestamp) {
|
||||
|
||||
this.enableWordTimestamp = enableWordTimestamp;
|
||||
}
|
||||
|
||||
public Boolean isEnablePhonemeTimestamp() {
|
||||
|
||||
return enablePhonemeTimestamp;
|
||||
}
|
||||
|
||||
public void setEnablePhonemeTimestamp(Boolean enablePhonemeTimestamp) {
|
||||
|
||||
this.enablePhonemeTimestamp = enablePhonemeTimestamp;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a options instances.
|
||||
*/
|
||||
public static class Builder {
|
||||
|
||||
private final TongYiAudioSpeechOptions options = new TongYiAudioSpeechOptions();
|
||||
|
||||
public Builder withModel(String model) {
|
||||
|
||||
options.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withText(String text) {
|
||||
|
||||
options.text = text;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withTextType(SpeechSynthesisTextType textType) {
|
||||
|
||||
options.textType = textType;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withFormat(SpeechSynthesisAudioFormat format) {
|
||||
|
||||
options.format = format;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withSampleRate(Integer sampleRate) {
|
||||
|
||||
options.sampleRate = sampleRate;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withVolume(Integer volume) {
|
||||
|
||||
options.volume = volume;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withRate(Float rate) {
|
||||
|
||||
options.rate = rate;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withPitch(Float pitch) {
|
||||
|
||||
options.pitch = pitch;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withEnableWordTimestamp(Boolean enableWordTimestamp) {
|
||||
|
||||
options.enableWordTimestamp = enableWordTimestamp;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withEnablePhonemeTimestamp(Boolean enablePhonemeTimestamp) {
|
||||
|
||||
options.enablePhonemeTimestamp = enablePhonemeTimestamp;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TongYiAudioSpeechOptions build() {
|
||||
|
||||
return options;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.speech;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.audio.AudioSpeechModels;
|
||||
import com.alibaba.dashscope.audio.tts.SpeechSynthesisAudioFormat;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||
|
||||
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
|
||||
|
||||
/**
|
||||
* TongYi audio speech configuration properties.
|
||||
*
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
@ConfigurationProperties(TongYiAudioSpeechProperties.CONFIG_PREFIX)
|
||||
public class TongYiAudioSpeechProperties {
|
||||
|
||||
/**
|
||||
* Spring Cloud Alibaba AI configuration prefix.
|
||||
*/
|
||||
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "audio.speech";
|
||||
/**
|
||||
* Default TongYi Chat model.
|
||||
*/
|
||||
public static final String DEFAULT_AUDIO_MODEL_NAME = AudioSpeechModels.SAMBERT_ZHICHU_V1;
|
||||
|
||||
/**
|
||||
* Enable TongYiQWEN ai audio client.
|
||||
*/
|
||||
private boolean enabled = true;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
private TongYiAudioSpeechOptions options = TongYiAudioSpeechOptions.builder()
|
||||
.withModel(DEFAULT_AUDIO_MODEL_NAME)
|
||||
.withFormat(SpeechSynthesisAudioFormat.WAV)
|
||||
.build();
|
||||
|
||||
public TongYiAudioSpeechOptions getOptions() {
|
||||
|
||||
return this.options;
|
||||
}
|
||||
|
||||
public void setOptions(TongYiAudioSpeechOptions options) {
|
||||
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
public boolean isEnabled() {
|
||||
|
||||
return this.enabled;
|
||||
}
|
||||
|
||||
public void setEnabled(boolean enabled) {
|
||||
|
||||
this.enabled = enabled;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
|
||||
|
||||
import org.springframework.ai.model.ModelResult;
|
||||
import org.springframework.lang.Nullable;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class Speech implements ModelResult<ByteBuffer> {
|
||||
|
||||
private final ByteBuffer audio;
|
||||
|
||||
private SpeechMetadata speechMetadata;
|
||||
|
||||
public Speech(ByteBuffer audio) {
|
||||
this.audio = audio;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteBuffer getOutput() {
|
||||
return this.audio;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SpeechMetadata getMetadata() {
|
||||
|
||||
return speechMetadata != null ? speechMetadata : SpeechMetadata.NULL;
|
||||
}
|
||||
|
||||
public Speech withSpeechMetadata(@Nullable SpeechMetadata speechMetadata) {
|
||||
|
||||
this.speechMetadata = speechMetadata;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
|
||||
if (this == o) {
|
||||
|
||||
return true;
|
||||
}
|
||||
if (!(o instanceof Speech that)) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return Arrays.equals(audio.array(), that.audio.array())
|
||||
&& Objects.equals(speechMetadata, that.speechMetadata);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
|
||||
return Objects.hash(Arrays.hashCode(audio.array()), speechMetadata);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
|
||||
return "Speech{" + "text=" + audio + ", speechMetadata=" + speechMetadata + '}';
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* The {@link SpeechMessage} class represents a single text message to
|
||||
* be converted to speech by the TongYi LLM TTS.
|
||||
*
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class SpeechMessage {
|
||||
|
||||
private String text;
|
||||
|
||||
/**
|
||||
* Constructs a new {@link SpeechMessage} object with the given text.
|
||||
* @param text the text to be converted to speech
|
||||
*/
|
||||
public SpeechMessage(String text) {
|
||||
this.text = text;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the text of this speech message.
|
||||
* @return the text of this speech message
|
||||
*/
|
||||
public String getText() {
|
||||
return text;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the text of this speech message.
|
||||
* @param text the new text for this speech message
|
||||
*/
|
||||
public void setText(String text) {
|
||||
this.text = text;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
|
||||
if (this == o) {
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!(o instanceof SpeechMessage that)) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return Objects.equals(text, that.text);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
|
||||
return Objects.hash(text);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
|
||||
|
||||
import org.springframework.ai.model.ResultMetadata;
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public interface SpeechMetadata extends ResultMetadata {
|
||||
|
||||
/**
|
||||
* Null Object.
|
||||
*/
|
||||
SpeechMetadata NULL = SpeechMetadata.create();
|
||||
|
||||
/**
|
||||
* Factory method used to construct a new {@link SpeechMetadata}.
|
||||
* @return a new {@link SpeechMetadata}
|
||||
*/
|
||||
static SpeechMetadata create() {
|
||||
return new SpeechMetadata() {
|
||||
};
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
|
||||
|
||||
import org.springframework.ai.model.Model;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.0.0-RC1
|
||||
*/
|
||||
|
||||
@FunctionalInterface
|
||||
public interface SpeechModel extends Model<SpeechPrompt, SpeechResponse> {
|
||||
|
||||
/**
|
||||
* Generates spoken audio from the provided text message.
|
||||
* @param message the text message to be converted to audio.
|
||||
* @return the resulting audio bytes.
|
||||
*/
|
||||
default ByteBuffer call(String message) {
|
||||
|
||||
SpeechPrompt prompt = new SpeechPrompt(message);
|
||||
|
||||
return call(prompt).getResult().getOutput();
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends a speech request to the TongYi TTS API and returns the resulting speech response.
|
||||
* @param request the speech prompt containing the input text and other parameters.
|
||||
* @return the speech response containing the generated audio.
|
||||
*/
|
||||
SpeechResponse call(SpeechPrompt request);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.audio.speech.TongYiAudioSpeechOptions;
|
||||
import org.springframework.ai.model.ModelRequest;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class SpeechPrompt implements ModelRequest<SpeechMessage> {
|
||||
|
||||
private TongYiAudioSpeechOptions speechOptions;
|
||||
|
||||
private final SpeechMessage message;
|
||||
|
||||
public SpeechPrompt(String instructions) {
|
||||
|
||||
this(new SpeechMessage(instructions), TongYiAudioSpeechOptions.builder().build());
|
||||
}
|
||||
|
||||
public SpeechPrompt(String instructions, TongYiAudioSpeechOptions speechOptions) {
|
||||
|
||||
this(new SpeechMessage(instructions), speechOptions);
|
||||
}
|
||||
|
||||
public SpeechPrompt(SpeechMessage speechMessage) {
|
||||
this(speechMessage, TongYiAudioSpeechOptions.builder().build());
|
||||
}
|
||||
|
||||
public SpeechPrompt(SpeechMessage speechMessage, TongYiAudioSpeechOptions speechOptions) {
|
||||
|
||||
this.message = speechMessage;
|
||||
this.speechOptions = speechOptions;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SpeechMessage getInstructions() {
|
||||
return this.message;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TongYiAudioSpeechOptions getOptions() {
|
||||
|
||||
return speechOptions;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
|
||||
if (this == o) {
|
||||
|
||||
return true;
|
||||
}
|
||||
if (!(o instanceof SpeechPrompt that)) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return Objects.equals(speechOptions, that.speechOptions) && Objects.equals(message, that.message);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
|
||||
return Objects.hash(speechOptions, message);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioSpeechResponseMetadata;
|
||||
import org.springframework.ai.model.ModelResponse;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class SpeechResponse implements ModelResponse<Speech> {
|
||||
|
||||
private final Speech speech;
|
||||
|
||||
private final TongYiAudioSpeechResponseMetadata speechResponseMetadata;
|
||||
|
||||
/**
|
||||
* Creates a new instance of SpeechResponse with the given speech result.
|
||||
* @param speech the speech result to be set in the SpeechResponse
|
||||
* @see Speech
|
||||
*/
|
||||
public SpeechResponse(Speech speech) {
|
||||
this(speech, TongYiAudioSpeechResponseMetadata.NULL);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new instance of SpeechResponse with the given speech result and speech
|
||||
* response metadata.
|
||||
* @param speech the speech result to be set in the SpeechResponse
|
||||
* @param speechResponseMetadata the speech response metadata to be set in the
|
||||
* SpeechResponse
|
||||
* @see Speech
|
||||
* @see TongYiAudioSpeechResponseMetadata
|
||||
*/
|
||||
public SpeechResponse(Speech speech, TongYiAudioSpeechResponseMetadata speechResponseMetadata) {
|
||||
|
||||
this.speech = speech;
|
||||
this.speechResponseMetadata = speechResponseMetadata;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Speech getResult() {
|
||||
return speech;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Speech> getResults() {
|
||||
return Collections.singletonList(speech);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TongYiAudioSpeechResponseMetadata getMetadata() {
|
||||
return speechResponseMetadata;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
|
||||
if (this == o) {
|
||||
|
||||
return true;
|
||||
}
|
||||
if (!(o instanceof SpeechResponse that)) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return Objects.equals(speech, that.speech)
|
||||
&& Objects.equals(speechResponseMetadata, that.speechResponseMetadata);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
|
||||
return Objects.hash(speech, speechResponseMetadata);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
|
||||
|
||||
import org.springframework.ai.model.StreamingModel;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
@FunctionalInterface
|
||||
public interface SpeechStreamModel extends StreamingModel<SpeechPrompt, SpeechResponse> {
|
||||
|
||||
/**
|
||||
* Generates a stream of audio bytes from the provided text message.
|
||||
*
|
||||
* @param message the text message to be converted to audio
|
||||
* @return a Flux of audio bytes representing the generated speech
|
||||
*/
|
||||
default Flux<ByteBuffer> stream(String message) {
|
||||
|
||||
SpeechPrompt prompt = new SpeechPrompt(message);
|
||||
return stream(prompt).map(SpeechResponse::getResult).map(Speech::getOutput);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends a speech request to the TongYi TTS API and returns a stream of the resulting
|
||||
* speech responses.
|
||||
* @param prompt the speech prompt containing the input text and other parameters.
|
||||
* @return a Flux of speech responses, each containing a portion of the generated audio.
|
||||
*/
|
||||
@Override
|
||||
Flux<SpeechResponse> stream(SpeechPrompt prompt);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.transcription;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.audio.AudioTranscriptionModels;
|
||||
import com.alibaba.cloud.ai.tongyi.audio.transcription.api.AudioTranscriptionPrompt;
|
||||
import com.alibaba.cloud.ai.tongyi.audio.transcription.api.AudioTranscriptionResponse;
|
||||
import com.alibaba.cloud.ai.tongyi.audio.transcription.api.AudioTranscriptionResult;
|
||||
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiException;
|
||||
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioTranscriptionResponseMetadata;
|
||||
import com.alibaba.dashscope.audio.asr.transcription.*;
|
||||
import org.springframework.ai.model.Model;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* TongYiAudioTranscriptionModel is a client for TongYi audio transcription service for
|
||||
* Spring Cloud Alibaba AI.
|
||||
* @author xYLiu
|
||||
* @author yuluo
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiAudioTranscriptionModel
|
||||
implements Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
|
||||
|
||||
/**
|
||||
* TongYi models options.
|
||||
*/
|
||||
private final TongYiAudioTranscriptionOptions defaultOptions;
|
||||
|
||||
/**
|
||||
* TongYi models api.
|
||||
*/
|
||||
private final Transcription transcription;
|
||||
|
||||
public TongYiAudioTranscriptionModel(Transcription transcription) {
|
||||
this(null, transcription);
|
||||
}
|
||||
|
||||
public TongYiAudioTranscriptionModel(TongYiAudioTranscriptionOptions defaultOptions,
|
||||
Transcription transcription) {
|
||||
Assert.notNull(transcription, "transcription must not be null");
|
||||
Assert.notNull(defaultOptions, "defaultOptions must not be null");
|
||||
|
||||
this.defaultOptions = defaultOptions;
|
||||
this.transcription = transcription;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AudioTranscriptionResponse call(AudioTranscriptionPrompt prompt) {
|
||||
|
||||
TranscriptionParam transcriptionParam;
|
||||
|
||||
if (prompt.getOptions() != null) {
|
||||
var param = merge(prompt.getOptions());
|
||||
transcriptionParam = toTranscriptionParam(param);
|
||||
transcriptionParam.setFileUrls(prompt.getOptions().getFileUrls());
|
||||
}
|
||||
else {
|
||||
Resource instructions = prompt.getInstructions();
|
||||
try {
|
||||
transcriptionParam = TranscriptionParam.builder()
|
||||
.model(AudioTranscriptionModels.Paraformer_V1)
|
||||
.fileUrls(List.of(String.valueOf(instructions.getURL())))
|
||||
.build();
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new TongYiException("Failed to create resource", e);
|
||||
}
|
||||
}
|
||||
|
||||
List<TranscriptionTaskResult> taskResultList;
|
||||
try {
|
||||
// Submit a transcription request
|
||||
TranscriptionResult result = transcription.asyncCall(transcriptionParam);
|
||||
// Wait for the transcription to complete
|
||||
result = transcription.wait(TranscriptionQueryParam
|
||||
.FromTranscriptionParam(transcriptionParam, result.getTaskId()));
|
||||
// Get the transcription results
|
||||
System.out.println(result.getOutput().getAsJsonObject());
|
||||
taskResultList = result.getResults();
|
||||
System.out.println(Arrays.toString(taskResultList.toArray()));
|
||||
|
||||
return new AudioTranscriptionResponse(
|
||||
taskResultList.stream().map(taskResult ->
|
||||
new AudioTranscriptionResult(taskResult.getTranscriptionUrl())
|
||||
).collect(Collectors.toList()),
|
||||
TongYiAudioTranscriptionResponseMetadata.from(result)
|
||||
);
|
||||
}
|
||||
catch (Exception e) {
|
||||
throw new TongYiException("Failed to call audio transcription", e);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public TongYiAudioTranscriptionOptions merge(TongYiAudioTranscriptionOptions target) {
|
||||
var mergeBuilder = TongYiAudioTranscriptionOptions.builder();
|
||||
|
||||
mergeBuilder
|
||||
.withModel(defaultOptions.getModel() != null ? defaultOptions.getModel()
|
||||
: target.getModel());
|
||||
mergeBuilder.withChannelId(
|
||||
defaultOptions.getChannelId() != null ? defaultOptions.getChannelId()
|
||||
: target.getChannelId());
|
||||
mergeBuilder.withDiarizationEnabled(defaultOptions.getDiarizationEnabled() != null
|
||||
? defaultOptions.getDiarizationEnabled()
|
||||
: target.getDiarizationEnabled());
|
||||
mergeBuilder.withDisfluencyRemovalEnabled(
|
||||
defaultOptions.getDisfluencyRemovalEnabled() != null
|
||||
? defaultOptions.getDisfluencyRemovalEnabled()
|
||||
: target.getDisfluencyRemovalEnabled());
|
||||
mergeBuilder.withTimestampAlignmentEnabled(
|
||||
defaultOptions.getTimestampAlignmentEnabled() != null
|
||||
? defaultOptions.getTimestampAlignmentEnabled()
|
||||
: target.getTimestampAlignmentEnabled());
|
||||
mergeBuilder.withSpecialWordFilter(defaultOptions.getSpecialWordFilter() != null
|
||||
? defaultOptions.getSpecialWordFilter()
|
||||
: target.getSpecialWordFilter());
|
||||
mergeBuilder.withAudioEventDetectionEnabled(
|
||||
defaultOptions.getAudioEventDetectionEnabled() != null
|
||||
? defaultOptions.getAudioEventDetectionEnabled()
|
||||
: target.getAudioEventDetectionEnabled());
|
||||
|
||||
return mergeBuilder.build();
|
||||
}
|
||||
|
||||
public TranscriptionParam toTranscriptionParam(
|
||||
TongYiAudioTranscriptionOptions source) {
|
||||
var mergeBuilder = TranscriptionParam.builder();
|
||||
|
||||
mergeBuilder.model(source.getModel() != null ? source.getModel()
|
||||
: AudioTranscriptionModels.Paraformer_V1);
|
||||
mergeBuilder.fileUrls(
|
||||
source.getFileUrls() != null ? source.getFileUrls() : new ArrayList<>());
|
||||
if (source.getPhraseId() != null) {
|
||||
mergeBuilder.phraseId(source.getPhraseId());
|
||||
}
|
||||
if (source.getChannelId() != null) {
|
||||
mergeBuilder.channelId(source.getChannelId());
|
||||
}
|
||||
if (source.getDiarizationEnabled() != null) {
|
||||
mergeBuilder.diarizationEnabled(source.getDiarizationEnabled());
|
||||
}
|
||||
if (source.getSpeakerCount() != null) {
|
||||
mergeBuilder.speakerCount(source.getSpeakerCount());
|
||||
}
|
||||
if (source.getDisfluencyRemovalEnabled() != null) {
|
||||
mergeBuilder.disfluencyRemovalEnabled(source.getDisfluencyRemovalEnabled());
|
||||
}
|
||||
if (source.getTimestampAlignmentEnabled() != null) {
|
||||
mergeBuilder.timestampAlignmentEnabled(source.getTimestampAlignmentEnabled());
|
||||
}
|
||||
if (source.getSpecialWordFilter() != null) {
|
||||
mergeBuilder.specialWordFilter(source.getSpecialWordFilter());
|
||||
}
|
||||
if (source.getAudioEventDetectionEnabled() != null) {
|
||||
mergeBuilder
|
||||
.audioEventDetectionEnabled(source.getAudioEventDetectionEnabled());
|
||||
}
|
||||
|
||||
return mergeBuilder.build();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.transcription;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.audio.AudioTranscriptionModels;
|
||||
import org.springframework.ai.model.ModelOptions;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author xYLiu
|
||||
* @author yuluo
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiAudioTranscriptionOptions implements ModelOptions {
|
||||
|
||||
private String model = AudioTranscriptionModels.Paraformer_V1;
|
||||
|
||||
private List<String> fileUrls = new ArrayList<>();
|
||||
|
||||
private String phraseId = null;
|
||||
|
||||
private List<Integer> channelId = Collections.singletonList(0);
|
||||
|
||||
private Boolean diarizationEnabled = false;
|
||||
|
||||
private Integer speakerCount = null;
|
||||
|
||||
private Boolean disfluencyRemovalEnabled = false;
|
||||
|
||||
private Boolean timestampAlignmentEnabled = false;
|
||||
|
||||
private String specialWordFilter = "";
|
||||
|
||||
private Boolean audioEventDetectionEnabled = false;
|
||||
|
||||
public static Builder builder() {
|
||||
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public String getModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
public void setModel(String model) {
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
public List<String> getFileUrls() {
|
||||
return fileUrls;
|
||||
}
|
||||
|
||||
public void setFileUrls(List<String> fileUrls) {
|
||||
this.fileUrls = fileUrls;
|
||||
}
|
||||
|
||||
public String getPhraseId() {
|
||||
return phraseId;
|
||||
}
|
||||
|
||||
public void setPhraseId(String phraseId) {
|
||||
this.phraseId = phraseId;
|
||||
}
|
||||
|
||||
public List<Integer> getChannelId() {
|
||||
return channelId;
|
||||
}
|
||||
|
||||
public void setChannelId(List<Integer> channelId) {
|
||||
this.channelId = channelId;
|
||||
}
|
||||
|
||||
public Boolean getDiarizationEnabled() {
|
||||
return diarizationEnabled;
|
||||
}
|
||||
|
||||
public void setDiarizationEnabled(Boolean diarizationEnabled) {
|
||||
this.diarizationEnabled = diarizationEnabled;
|
||||
}
|
||||
|
||||
public Integer getSpeakerCount() {
|
||||
return speakerCount;
|
||||
}
|
||||
|
||||
public void setSpeakerCount(Integer speakerCount) {
|
||||
this.speakerCount = speakerCount;
|
||||
}
|
||||
|
||||
public Boolean getDisfluencyRemovalEnabled() {
|
||||
return disfluencyRemovalEnabled;
|
||||
}
|
||||
|
||||
public void setDisfluencyRemovalEnabled(Boolean disfluencyRemovalEnabled) {
|
||||
this.disfluencyRemovalEnabled = disfluencyRemovalEnabled;
|
||||
}
|
||||
|
||||
public Boolean getTimestampAlignmentEnabled() {
|
||||
return timestampAlignmentEnabled;
|
||||
}
|
||||
|
||||
public void setTimestampAlignmentEnabled(Boolean timestampAlignmentEnabled) {
|
||||
this.timestampAlignmentEnabled = timestampAlignmentEnabled;
|
||||
}
|
||||
|
||||
public String getSpecialWordFilter() {
|
||||
return specialWordFilter;
|
||||
}
|
||||
|
||||
public void setSpecialWordFilter(String specialWordFilter) {
|
||||
this.specialWordFilter = specialWordFilter;
|
||||
}
|
||||
|
||||
public Boolean getAudioEventDetectionEnabled() {
|
||||
return audioEventDetectionEnabled;
|
||||
}
|
||||
|
||||
public void setAudioEventDetectionEnabled(Boolean audioEventDetectionEnabled) {
|
||||
this.audioEventDetectionEnabled = audioEventDetectionEnabled;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builder class for constructing TongYiAudioTranscriptionOptions instances.
|
||||
*/
|
||||
public static class Builder {
|
||||
|
||||
private final TongYiAudioTranscriptionOptions options = new TongYiAudioTranscriptionOptions();
|
||||
|
||||
public Builder withModel(String model) {
|
||||
options.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withFileUrls(List<String> fileUrls) {
|
||||
options.fileUrls = fileUrls;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withPhraseId(String phraseId) {
|
||||
options.phraseId = phraseId;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withChannelId(List<Integer> channelId) {
|
||||
options.channelId = channelId;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withDiarizationEnabled(Boolean diarizationEnabled) {
|
||||
options.diarizationEnabled = diarizationEnabled;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withSpeakerCount(Integer speakerCount) {
|
||||
options.speakerCount = speakerCount;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withDisfluencyRemovalEnabled(Boolean disfluencyRemovalEnabled) {
|
||||
options.disfluencyRemovalEnabled = disfluencyRemovalEnabled;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withTimestampAlignmentEnabled(Boolean timestampAlignmentEnabled) {
|
||||
options.timestampAlignmentEnabled = timestampAlignmentEnabled;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withSpecialWordFilter(String specialWordFilter) {
|
||||
options.specialWordFilter = specialWordFilter;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withAudioEventDetectionEnabled(
|
||||
Boolean audioEventDetectionEnabled) {
|
||||
options.audioEventDetectionEnabled = audioEventDetectionEnabled;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TongYiAudioTranscriptionOptions build() {
|
||||
// Perform any necessary validation here before returning the built object
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.transcription;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.audio.AudioTranscriptionModels;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||
|
||||
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
|
||||
|
||||
/**
|
||||
* @author xYLiu
|
||||
* @author yuluo
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
@ConfigurationProperties(TongYiAudioTranscriptionProperties.CONFIG_PREFIX)
|
||||
public class TongYiAudioTranscriptionProperties {
|
||||
|
||||
/**
|
||||
* Spring Cloud Alibaba AI configuration prefix.
|
||||
*/
|
||||
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "audio.transcription";
|
||||
|
||||
/**
|
||||
* Default TongYi Chat model.
|
||||
*/
|
||||
public static final String DEFAULT_AUDIO_MODEL_NAME = AudioTranscriptionModels.Paraformer_V1;
|
||||
|
||||
/**
|
||||
* Enable TongYiQWEN ai audio client.
|
||||
*/
|
||||
private boolean enabled = true;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
private TongYiAudioTranscriptionOptions options = TongYiAudioTranscriptionOptions
|
||||
.builder().withModel(DEFAULT_AUDIO_MODEL_NAME).build();
|
||||
|
||||
public TongYiAudioTranscriptionOptions getOptions() {
|
||||
|
||||
return this.options;
|
||||
}
|
||||
|
||||
public void setOptions(TongYiAudioTranscriptionOptions options) {
|
||||
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
public boolean isEnabled() {
|
||||
|
||||
return this.enabled;
|
||||
}
|
||||
|
||||
public void setEnabled(boolean enabled) {
|
||||
|
||||
this.enabled = enabled;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.transcription.api;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.audio.transcription.TongYiAudioTranscriptionOptions;
|
||||
import org.springframework.ai.model.ModelRequest;
|
||||
import org.springframework.core.io.Resource;
|
||||
|
||||
/**
|
||||
* @author xYLiu
|
||||
* @author yuluo
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class AudioTranscriptionPrompt implements ModelRequest<Resource> {
|
||||
|
||||
private Resource audioResource;
|
||||
|
||||
private TongYiAudioTranscriptionOptions transcriptionOptions;
|
||||
|
||||
public AudioTranscriptionPrompt(Resource resource) {
|
||||
this.audioResource = resource;
|
||||
}
|
||||
|
||||
public AudioTranscriptionPrompt(Resource resource, TongYiAudioTranscriptionOptions transcriptionOptions) {
|
||||
this.audioResource = resource;
|
||||
this.transcriptionOptions = transcriptionOptions;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Resource getInstructions() {
|
||||
|
||||
return audioResource;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TongYiAudioTranscriptionOptions getOptions() {
|
||||
|
||||
return transcriptionOptions;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.transcription.api;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioTranscriptionResponseMetadata;
|
||||
import org.springframework.ai.model.ModelResponse;
|
||||
import org.springframework.ai.model.ResponseMetadata;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author xYLiu
|
||||
* @author yuluo
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class AudioTranscriptionResponse implements ModelResponse<AudioTranscriptionResult> {
|
||||
|
||||
private List<AudioTranscriptionResult> resultList;
|
||||
|
||||
private TongYiAudioTranscriptionResponseMetadata transcriptionResponseMetadata;
|
||||
|
||||
public AudioTranscriptionResponse(List<AudioTranscriptionResult> result) {
|
||||
|
||||
this(result, TongYiAudioTranscriptionResponseMetadata.NULL);
|
||||
}
|
||||
|
||||
public AudioTranscriptionResponse(List<AudioTranscriptionResult> result,
|
||||
TongYiAudioTranscriptionResponseMetadata transcriptionResponseMetadata) {
|
||||
|
||||
this.resultList = List.copyOf(result);
|
||||
this.transcriptionResponseMetadata = transcriptionResponseMetadata;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AudioTranscriptionResult getResult() {
|
||||
|
||||
return this.resultList.get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AudioTranscriptionResult> getResults() {
|
||||
|
||||
return this.resultList;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ResponseMetadata getMetadata() {
|
||||
|
||||
return this.transcriptionResponseMetadata;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.audio.transcription.api;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioTranscriptionMetadata;
|
||||
import org.springframework.ai.model.ModelResult;
|
||||
import org.springframework.ai.model.ResultMetadata;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
public class AudioTranscriptionResult implements ModelResult<String> {
|
||||
|
||||
private String text;
|
||||
|
||||
private TongYiAudioTranscriptionMetadata transcriptionMetadata;
|
||||
|
||||
public AudioTranscriptionResult(String text) {
|
||||
this.text = text;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getOutput() {
|
||||
|
||||
return this.text;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ResultMetadata getMetadata() {
|
||||
|
||||
return transcriptionMetadata != null ? transcriptionMetadata : TongYiAudioTranscriptionMetadata.NULL;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
AudioTranscriptionResult that = (AudioTranscriptionResult) o;
|
||||
return Objects.equals(text, that.text) && Objects.equals(transcriptionMetadata, that.transcriptionMetadata);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(text, transcriptionMetadata);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,482 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.chat;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiException;
|
||||
import com.alibaba.dashscope.aigc.conversation.ConversationParam;
|
||||
import com.alibaba.dashscope.aigc.generation.Generation;
|
||||
import com.alibaba.dashscope.aigc.generation.GenerationOutput;
|
||||
import com.alibaba.dashscope.aigc.generation.GenerationResult;
|
||||
import com.alibaba.dashscope.common.MessageManager;
|
||||
import com.alibaba.dashscope.common.Role;
|
||||
import com.alibaba.dashscope.exception.InputRequiredException;
|
||||
import com.alibaba.dashscope.exception.NoApiKeyException;
|
||||
import com.alibaba.dashscope.tools.FunctionDefinition;
|
||||
import com.alibaba.dashscope.tools.ToolCallBase;
|
||||
import com.alibaba.dashscope.tools.ToolCallFunction;
|
||||
import com.alibaba.dashscope.utils.ApiKeywords;
|
||||
import com.alibaba.dashscope.utils.JsonUtils;
|
||||
import io.reactivex.Flowable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
|
||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
|
||||
/**
|
||||
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal Alibaba DashScope}
|
||||
* backed by {@link Generation}.
|
||||
*
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
* @see ChatModel
|
||||
* @see com.alibaba.dashscope.aigc.generation
|
||||
*/
|
||||
|
||||
public class TongYiChatModel extends
|
||||
AbstractFunctionCallSupport<
|
||||
com.alibaba.dashscope.common.Message,
|
||||
ConversationParam,
|
||||
GenerationResult>
|
||||
implements ChatModel, StreamingChatModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(TongYiChatModel.class);
|
||||
|
||||
/**
|
||||
* DashScope generation client.
|
||||
*/
|
||||
private final Generation generation;
|
||||
|
||||
/**
|
||||
* The TongYi models default chat completion api.
|
||||
*/
|
||||
private TongYiChatOptions defaultOptions;
|
||||
|
||||
/**
|
||||
* User role message manager.
|
||||
*/
|
||||
@Autowired
|
||||
private MessageManager msgManager;
|
||||
|
||||
/**
 * Creates a chat model with built-in default options (top-p 0.8, search enabled,
 * {@code MESSAGE} result format) and no function callback context.
 *
 * @param generation DashScope generation client.
 */
public TongYiChatModel(Generation generation) {
    // The two-argument constructor passes a null FunctionCallbackContext.
    this(generation, TongYiChatOptions.builder()
            .withTopP(0.8)
            .withEnableSearch(true)
            .withResultFormat(ConversationParam.ResultFormat.MESSAGE)
            .build());
}
|
||||
|
||||
/**
 * Creates a chat model with the given default options and no function callback context.
 *
 * @param generation DashScope generation client.
 * @param options default TongYi chat options.
 */
public TongYiChatModel(Generation generation, TongYiChatOptions options) {
    this(generation, options, null);
}
|
||||
|
||||
/**
 * Creates a chat model.
 *
 * @param generation DashScope generation client used for blocking and streaming calls.
 * @param options default TongYi chat options, returned by {@link #getDefaultOptions()}.
 * @param functionCallbackContext context handed to the function-call support base class.
 */
public TongYiChatModel(Generation generation, TongYiChatOptions options,
        FunctionCallbackContext functionCallbackContext) {
    super(functionCallbackContext);
    this.defaultOptions = options;
    this.generation = generation;
}
|
||||
|
||||
/**
|
||||
* Get default sca chat options.
|
||||
*
|
||||
* @return TongYiChatOptions default object.
|
||||
*/
|
||||
public TongYiChatOptions getDefaultOptions() {
|
||||
|
||||
return this.defaultOptions;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse call(Prompt prompt) {
|
||||
|
||||
ConversationParam params = toTongYiChatParams(prompt);
|
||||
|
||||
// TongYi models context loader.
|
||||
com.alibaba.dashscope.common.Message message = new com.alibaba.dashscope.common.Message();
|
||||
message.setRole(Role.USER.getValue());
|
||||
message.setContent(prompt.getContents());
|
||||
msgManager.add(message);
|
||||
params.setMessages(msgManager.get());
|
||||
|
||||
logger.trace("TongYi ConversationOptions: {}", params);
|
||||
GenerationResult chatCompletions = this.callWithFunctionSupport(params);
|
||||
logger.trace("TongYi ConversationOptions: {}", params);
|
||||
|
||||
msgManager.add(chatCompletions);
|
||||
|
||||
List<org.springframework.ai.chat.model.Generation> generations =
|
||||
chatCompletions
|
||||
.getOutput()
|
||||
.getChoices()
|
||||
.stream()
|
||||
.map(choice ->
|
||||
new org.springframework.ai.chat.model.Generation(
|
||||
choice
|
||||
.getMessage()
|
||||
.getContent()
|
||||
).withGenerationMetadata(generateChoiceMetadata(choice)
|
||||
))
|
||||
.toList();
|
||||
|
||||
return new ChatResponse(generations);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||
|
||||
Flowable<GenerationResult> genRes;
|
||||
ConversationParam tongYiChatParams = toTongYiChatParams(prompt);
|
||||
|
||||
// See https://help.aliyun.com/zh/dashscope/developer-reference/api-details?spm=a2c4g.11186623.0.0.655fc11aRR0jj7#b9ad0a10cfhpe
|
||||
// tongYiChatParams.setIncrementalOutput(true);
|
||||
|
||||
try {
|
||||
genRes = generation.streamCall(tongYiChatParams);
|
||||
}
|
||||
catch (NoApiKeyException | InputRequiredException e) {
|
||||
logger.warn("TongYi chat client: " + e.getMessage());
|
||||
throw new TongYiException(e.getMessage());
|
||||
}
|
||||
|
||||
return Flux.from(genRes)
|
||||
.flatMap(
|
||||
message -> Flux.just(
|
||||
message.getOutput()
|
||||
.getChoices()
|
||||
.get(0)
|
||||
.getMessage()
|
||||
.getContent())
|
||||
.map(content -> {
|
||||
var gen = new org.springframework.ai.chat.model.Generation(content)
|
||||
.withGenerationMetadata(generateChoiceMetadata(
|
||||
message.getOutput()
|
||||
.getChoices()
|
||||
.get(0)
|
||||
));
|
||||
return new ChatResponse(List.of(gen));
|
||||
})
|
||||
)
|
||||
.publishOn(Schedulers.parallel());
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration properties to Qwen model params.
|
||||
* Test access.
|
||||
*
|
||||
* @param prompt {@link Prompt}
|
||||
* @return Qwen models params {@link ConversationParam}
|
||||
*/
|
||||
public ConversationParam toTongYiChatParams(Prompt prompt) {
|
||||
|
||||
Set<String> functionsForThisRequest = new HashSet<>();
|
||||
|
||||
List<com.alibaba.dashscope.common.Message> tongYiMessage = prompt.getInstructions().stream()
|
||||
.map(this::fromSpringAIMessage)
|
||||
.toList();
|
||||
|
||||
ConversationParam chatParams = ConversationParam.builder()
|
||||
.messages(tongYiMessage)
|
||||
// models setting
|
||||
// {@link HalfDuplexServiceParam#models}
|
||||
.model(Generation.Models.QWEN_TURBO)
|
||||
// {@link GenerationOutput}
|
||||
.resultFormat(ConversationParam.ResultFormat.MESSAGE)
|
||||
.incrementalOutput(true)
|
||||
|
||||
.build();
|
||||
|
||||
if (this.defaultOptions != null) {
|
||||
|
||||
chatParams = merge(chatParams, this.defaultOptions);
|
||||
Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, !IS_RUNTIME_CALL);
|
||||
functionsForThisRequest.addAll(defaultEnabledFunctions);
|
||||
}
|
||||
if (prompt.getOptions() != null) {
|
||||
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
|
||||
TongYiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
|
||||
ChatOptions.class, TongYiChatOptions.class);
|
||||
|
||||
chatParams = merge(updatedRuntimeOptions, chatParams);
|
||||
|
||||
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
|
||||
IS_RUNTIME_CALL);
|
||||
functionsForThisRequest.addAll(promptEnabledFunctions);
|
||||
|
||||
}
|
||||
else {
|
||||
throw new IllegalArgumentException("Prompt options are not of type ConversationParam:"
|
||||
+ prompt.getOptions().getClass().getSimpleName());
|
||||
}
|
||||
}
|
||||
|
||||
// Add the enabled functions definitions to the request's tools parameter.
|
||||
|
||||
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
|
||||
List<FunctionDefinition> tools = this.getFunctionTools(functionsForThisRequest);
|
||||
|
||||
// todo chatParams.setTools(tools)
|
||||
}
|
||||
|
||||
return chatParams;
|
||||
}
|
||||
|
||||
private ChatGenerationMetadata generateChoiceMetadata(GenerationOutput.Choice choice) {
|
||||
|
||||
return ChatGenerationMetadata.from(
|
||||
String.valueOf(choice.getFinishReason()),
|
||||
choice.getMessage().getContent()
|
||||
);
|
||||
}
|
||||
|
||||
private List<FunctionDefinition> getFunctionTools(Set<String> functionNames) {
|
||||
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
|
||||
|
||||
FunctionDefinition functionDefinition = FunctionDefinition.builder()
|
||||
.name(functionCallback.getName())
|
||||
.description(functionCallback.getDescription())
|
||||
.parameters(JsonUtils.parametersToJsonObject(
|
||||
ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema())
|
||||
))
|
||||
.build();
|
||||
|
||||
return functionDefinition;
|
||||
}).toList();
|
||||
}
|
||||
|
||||
|
||||
private ConversationParam merge(ConversationParam tongYiParams, TongYiChatOptions scaChatParams) {
|
||||
|
||||
if (scaChatParams == null) {
|
||||
|
||||
return tongYiParams;
|
||||
}
|
||||
|
||||
return ConversationParam.builder()
|
||||
.messages(tongYiParams.getMessages())
|
||||
.maxTokens((tongYiParams.getMaxTokens() != null) ? tongYiParams.getMaxTokens() : scaChatParams.getMaxTokens())
|
||||
// When merge options. Because ConversationParams is must not null. So is setting.
|
||||
.model(scaChatParams.getModel())
|
||||
.resultFormat((tongYiParams.getResultFormat() != null) ? tongYiParams.getResultFormat() : scaChatParams.getResultFormat())
|
||||
.enableSearch((tongYiParams.getEnableSearch() != null) ? tongYiParams.getEnableSearch() : scaChatParams.getEnableSearch())
|
||||
.topK((tongYiParams.getTopK() != null) ? tongYiParams.getTopK() : scaChatParams.getTopK())
|
||||
.topP((tongYiParams.getTopP() != null) ? tongYiParams.getTopP() : scaChatParams.getTopP())
|
||||
.incrementalOutput((tongYiParams.getIncrementalOutput() != null) ? tongYiParams.getIncrementalOutput() : scaChatParams.getIncrementalOutput())
|
||||
.temperature((tongYiParams.getTemperature() != null) ? tongYiParams.getTemperature() : scaChatParams.getTemperature())
|
||||
.repetitionPenalty((tongYiParams.getRepetitionPenalty() != null) ? tongYiParams.getRepetitionPenalty() : scaChatParams.getRepetitionPenalty())
|
||||
.seed((tongYiParams.getSeed() != null) ? tongYiParams.getSeed() : scaChatParams.getSeed())
|
||||
.build();
|
||||
|
||||
}
|
||||
|
||||
private ConversationParam merge(TongYiChatOptions scaChatParams, ConversationParam tongYiParams) {
|
||||
|
||||
if (scaChatParams == null) {
|
||||
|
||||
return tongYiParams;
|
||||
}
|
||||
|
||||
ConversationParam mergedTongYiParams = ConversationParam.builder()
|
||||
.model(Generation.Models.QWEN_TURBO)
|
||||
.messages(tongYiParams.getMessages())
|
||||
.build();
|
||||
mergedTongYiParams = merge(tongYiParams, scaChatParams);
|
||||
|
||||
if (scaChatParams.getMaxTokens() != null) {
|
||||
mergedTongYiParams.setMaxTokens(scaChatParams.getMaxTokens());
|
||||
}
|
||||
|
||||
if (scaChatParams.getStop() != null) {
|
||||
mergedTongYiParams.setStopStrings(scaChatParams.getStop());
|
||||
}
|
||||
|
||||
if (scaChatParams.getTemperature() != null) {
|
||||
mergedTongYiParams.setTemperature(scaChatParams.getTemperature());
|
||||
}
|
||||
|
||||
if (scaChatParams.getTopK() != null) {
|
||||
mergedTongYiParams.setTopK(scaChatParams.getTopK());
|
||||
}
|
||||
|
||||
if (scaChatParams.getTopK() != null) {
|
||||
mergedTongYiParams.setTopK(scaChatParams.getTopK());
|
||||
}
|
||||
|
||||
return mergedTongYiParams;
|
||||
}
|
||||
|
||||
private com.alibaba.dashscope.common.Message fromSpringAIMessage(Message message) {
|
||||
|
||||
return switch (message.getMessageType()) {
|
||||
case USER -> com.alibaba.dashscope.common.Message.builder()
|
||||
.role(Role.USER.getValue())
|
||||
.content(message.getContent())
|
||||
.build();
|
||||
case SYSTEM -> com.alibaba.dashscope.common.Message.builder()
|
||||
.role(Role.SYSTEM.getValue())
|
||||
.content(message.getContent())
|
||||
.build();
|
||||
case ASSISTANT -> com.alibaba.dashscope.common.Message.builder()
|
||||
.role(Role.ASSISTANT.getValue())
|
||||
.content(message.getContent())
|
||||
.build();
|
||||
default -> throw new IllegalArgumentException("Unknown message type " + message.getMessageType());
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ConversationParam doCreateToolResponseRequest(
|
||||
ConversationParam previousRequest,
|
||||
com.alibaba.dashscope.common.Message responseMessage,
|
||||
List<com.alibaba.dashscope.common.Message> conversationHistory
|
||||
) {
|
||||
for (ToolCallBase toolCall : responseMessage.getToolCalls()) {
|
||||
if (toolCall instanceof ToolCallFunction toolCallFunction) {
|
||||
if (toolCallFunction.getFunction() != null) {
|
||||
var functionName = toolCallFunction.getFunction().getName();
|
||||
var functionArguments = toolCallFunction.getFunction().getArguments();
|
||||
|
||||
if (!this.functionCallbackRegister.containsKey(functionName)) {
|
||||
throw new IllegalStateException("No function callback found for function name: " + functionName);
|
||||
}
|
||||
|
||||
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
|
||||
|
||||
// Add the function response to the conversation.
|
||||
conversationHistory
|
||||
.add(com.alibaba.dashscope.common.Message.builder()
|
||||
.content(functionResponse)
|
||||
.role(Role.BOT.getValue())
|
||||
.toolCallId(toolCall.getId())
|
||||
.build()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
ConversationParam newRequest = ConversationParam.builder().messages(conversationHistory).build();
|
||||
|
||||
// todo: No @JsonProperty fields.
|
||||
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ConversationParam.class);
|
||||
|
||||
return newRequest;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<com.alibaba.dashscope.common.Message> doGetUserMessages(ConversationParam request) {
|
||||
|
||||
return request.getMessages();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected com.alibaba.dashscope.common.Message doGetToolResponseMessage(GenerationResult response) {
|
||||
|
||||
var message = response.getOutput().getChoices().get(0).getMessage();
|
||||
var assistantMessage = com.alibaba.dashscope.common.Message.builder().role(Role.ASSISTANT.getValue())
|
||||
.content("").build();
|
||||
assistantMessage.setToolCalls(message.getToolCalls());
|
||||
|
||||
return assistantMessage;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected GenerationResult doChatCompletion(ConversationParam request) {
|
||||
|
||||
GenerationResult result;
|
||||
try {
|
||||
result = generation.call(request);
|
||||
}
|
||||
catch (NoApiKeyException | InputRequiredException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Flux<GenerationResult> doChatCompletionStream(ConversationParam request) {
|
||||
final Flowable<GenerationResult> genRes;
|
||||
try {
|
||||
genRes = generation.streamCall(request);
|
||||
}
|
||||
catch (NoApiKeyException | InputRequiredException e) {
|
||||
logger.warn("TongYi chat client: " + e.getMessage());
|
||||
throw new TongYiException(e.getMessage());
|
||||
}
|
||||
return Flux.from(genRes);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean isToolFunctionCall(GenerationResult response) {
|
||||
|
||||
if (response == null || CollectionUtils.isEmpty(response.getOutput().getChoices())) {
|
||||
|
||||
return false;
|
||||
}
|
||||
var choice = response.getOutput().getChoices().get(0);
|
||||
if (choice == null || choice.getFinishReason() == null) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return Objects.equals(choice.getFinishReason(), ApiKeywords.TOOL_CALLS);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,463 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.chat;
|
||||
|
||||
import com.alibaba.dashscope.aigc.generation.Generation;
|
||||
import com.alibaba.dashscope.aigc.generation.GenerationParam;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiChatOptions implements FunctionCallingOptions, ChatOptions {
|
||||
|
||||
/**
|
||||
* TongYi Models.
|
||||
* {@link Generation.Models}
|
||||
*/
|
||||
private String model = Generation.Models.QWEN_TURBO;
|
||||
|
||||
/**
|
||||
* The random number seed used in generation, the user controls the randomness of the content generated by the model.
|
||||
* seed supports unsigned 64-bit integers, with a default value of 1234.
|
||||
* when using seed, the model will generate the same or similar results as much as possible, but there is currently no guarantee that the results will be exactly the same each time.
|
||||
*/
|
||||
private Integer seed = 1234;
|
||||
|
||||
/**
|
||||
* Used to specify the maximum number of tokens that the model can generate when generating content,
|
||||
* it defines the upper limit of generation but does not guarantee that this number will be generated every time.
|
||||
* For qwen-turbo the maximum and default values are 1500 tokens.
|
||||
* The qwen-max, qwen-max-1201, qwen-max-longcontext, and qwen-plus models have a maximum and default value of 2000 tokens.
|
||||
*/
|
||||
private Integer maxTokens = 1500;
|
||||
|
||||
/**
|
||||
* The generation process kernel sampling method probability threshold,
|
||||
* for example, takes the value of 0.8, only retains the smallest set of the most probable tokens with probabilities that add up to greater than or equal to 0.8 as the candidate set.
|
||||
* The range of values is (0,1.0), the larger the value, the higher the randomness of generation; the lower the value, the higher the certainty of generation.
|
||||
*/
|
||||
private Double topP = 0.8;
|
||||
|
||||
/**
|
||||
* The size of the sampling candidate set at the time of generation.
|
||||
* For example, with a value of 50, only the 50 highest scoring tokens in a single generation will form a randomly sampled candidate set.
|
||||
* The larger the value, the higher the randomness of the generation; the smaller the value, the higher the certainty of the generation.
|
||||
* This parameter is not passed by default, and a value of None or when top_k is greater than 100 indicates that the top_k policy is not enabled,
|
||||
* at which time, only the top_p policy is in effect.
|
||||
*/
|
||||
private Integer topK;
|
||||
|
||||
/**
|
||||
* Used to control the repeatability of model generation.
|
||||
* Increasing repetition_penalty reduces the repetition of model generation. 1.0 means no penalty.
|
||||
*/
|
||||
private Double repetitionPenalty = 1.1;
|
||||
|
||||
/**
|
||||
* is used to control the degree of randomness and diversity.
|
||||
* Specifically, the temperature value controls the extent to which the probability distribution of each candidate word is smoothed when generating text.
|
||||
* Higher values of temperature reduce the peak of the probability distribution, allowing more low-probability words to be selected and generating more diverse results,
|
||||
* while lower values of temperature enhance the peak of the probability distribution, making it easier for high-probability words to be selected and generating more certain results.
|
||||
* Range: [0, 2), 0 is not recommended, meaningless.
|
||||
* java version >= 2.5.1
|
||||
*/
|
||||
private Double temperature = 0.85;
|
||||
|
||||
/**
|
||||
* The stop parameter is used to realize precise control of the content generation process, automatically stopping when the generated content is about to contain the specified string or token_ids,
|
||||
* and the generated content does not contain the specified content.
|
||||
* For example, if stop is specified as "Hello", it means stop when "Hello" will be generated; if stop is specified as [37763, 367], it means stop when "Observation" will be generated.
|
||||
* The stop parameter can be passed as a list of arrays of strings or token_ids to support the scenario of using multiple stops.
|
||||
* Explanation: Do not mix strings and token_ids in list mode, the element types should be the same in list mode.
|
||||
*/
|
||||
private List<String> stop;
|
||||
|
||||
/**
|
||||
* Whether or not to use stream output. When outputting the result in stream mode, the interface returns the result as generator,
|
||||
* you need to iterate to get the result, the default output is the whole sequence of the current generation for each output,
|
||||
* the last output is the final result of all the generation, you can change the output mode to non-incremental output by the parameter incremental_output to False.
|
||||
*/
|
||||
private Boolean stream = false;
|
||||
|
||||
/**
|
||||
* The model has a built-in Internet search service.
|
||||
* This parameter controls whether the model refers to the use of Internet search results when generating text. The values are as follows:
|
||||
* True: enable internet search, the model will use the search result as the reference information in the text generation process, but the model will "judge by itself" whether to use the internet search result based on its internal logic.
|
||||
* False (default): Internet search is disabled.
|
||||
*/
|
||||
private Boolean enableSearch = false;
|
||||
|
||||
/**
|
||||
* [text|message], defaults to text, when it is message,
|
||||
* the output refers to the message result example.
|
||||
* It is recommended to prioritize the use of message format.
|
||||
*/
|
||||
private String resultFormat = GenerationParam.ResultFormat.MESSAGE;
|
||||
|
||||
/**
|
||||
* Control the streaming output mode, that is, the content will contain the content has been output;
|
||||
* set to True, will open the incremental output mode, the output will not contain the content has been output,
|
||||
* you need to splice the whole output, refer to the streaming output sample code.
|
||||
*/
|
||||
private Boolean incrementalOutput = false;
|
||||
|
||||
/**
|
||||
* A list of tools that the model can optionally call.
|
||||
* Currently only functions are supported, and even if multiple functions are entered, the model will only select one to generate the result.
|
||||
*/
|
||||
private List<String> tools;
|
||||
|
||||
@Override
|
||||
public Float getTemperature() {
|
||||
|
||||
return this.temperature.floatValue();
|
||||
}
|
||||
|
||||
public void setTemperature(Float temperature) {
|
||||
|
||||
this.temperature = temperature.doubleValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Float getTopP() {
|
||||
|
||||
return this.topP.floatValue();
|
||||
}
|
||||
|
||||
public void setTopP(Float topP) {
|
||||
|
||||
this.topP = topP.doubleValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getTopK() {
|
||||
|
||||
return this.topK;
|
||||
}
|
||||
|
||||
public void setTopK(Integer topK) {
|
||||
|
||||
this.topK = topK;
|
||||
}
|
||||
|
||||
public String getModel() {
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
public void setModel(String model) {
|
||||
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
public Integer getSeed() {
|
||||
|
||||
return seed;
|
||||
}
|
||||
|
||||
public String getResultFormat() {
|
||||
|
||||
return resultFormat;
|
||||
}
|
||||
|
||||
public void setResultFormat(String resultFormat) {
|
||||
|
||||
this.resultFormat = resultFormat;
|
||||
}
|
||||
|
||||
public void setSeed(Integer seed) {
|
||||
|
||||
this.seed = seed;
|
||||
}
|
||||
|
||||
public Integer getMaxTokens() {
|
||||
|
||||
return maxTokens;
|
||||
}
|
||||
|
||||
public void setMaxTokens(Integer maxTokens) {
|
||||
|
||||
this.maxTokens = maxTokens;
|
||||
}
|
||||
|
||||
public Float getRepetitionPenalty() {
|
||||
|
||||
return repetitionPenalty.floatValue();
|
||||
}
|
||||
|
||||
public void setRepetitionPenalty(Float repetitionPenalty) {
|
||||
|
||||
this.repetitionPenalty = repetitionPenalty.doubleValue();
|
||||
}
|
||||
|
||||
public List<String> getStop() {
|
||||
|
||||
return stop;
|
||||
}
|
||||
|
||||
public void setStop(List<String> stop) {
|
||||
|
||||
this.stop = stop;
|
||||
}
|
||||
|
||||
public Boolean getStream() {
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
public void setStream(Boolean stream) {
|
||||
|
||||
this.stream = stream;
|
||||
}
|
||||
|
||||
public Boolean getEnableSearch() {
|
||||
|
||||
return enableSearch;
|
||||
}
|
||||
|
||||
public void setEnableSearch(Boolean enableSearch) {
|
||||
|
||||
this.enableSearch = enableSearch;
|
||||
}
|
||||
|
||||
public Boolean getIncrementalOutput() {
|
||||
|
||||
return incrementalOutput;
|
||||
}
|
||||
|
||||
public void setIncrementalOutput(Boolean incrementalOutput) {
|
||||
|
||||
this.incrementalOutput = incrementalOutput;
|
||||
}
|
||||
|
||||
public List<String> getTools() {
|
||||
|
||||
return tools;
|
||||
}
|
||||
|
||||
public void setTools(List<String> tools) {
|
||||
|
||||
this.tools = tools;
|
||||
}
|
||||
|
||||
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
|
||||
|
||||
private Set<String> functions = new HashSet<>();
|
||||
|
||||
@Override
|
||||
public List<FunctionCallback> getFunctionCallbacks() {
|
||||
|
||||
return this.functionCallbacks;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
|
||||
|
||||
this.functionCallbacks = functionCallbacks;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getFunctions() {
|
||||
|
||||
return this.functions;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setFunctions(Set<String> functions) {
|
||||
|
||||
this.functions = functions;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
TongYiChatOptions that = (TongYiChatOptions) o;
|
||||
|
||||
return Objects.equals(model, that.model)
|
||||
&& Objects.equals(seed, that.seed)
|
||||
&& Objects.equals(maxTokens, that.maxTokens)
|
||||
&& Objects.equals(topP, that.topP)
|
||||
&& Objects.equals(topK, that.topK)
|
||||
&& Objects.equals(repetitionPenalty, that.repetitionPenalty)
|
||||
&& Objects.equals(temperature, that.temperature)
|
||||
&& Objects.equals(stop, that.stop)
|
||||
&& Objects.equals(stream, that.stream)
|
||||
&& Objects.equals(enableSearch, that.enableSearch)
|
||||
&& Objects.equals(resultFormat, that.resultFormat)
|
||||
&& Objects.equals(incrementalOutput, that.incrementalOutput)
|
||||
&& Objects.equals(tools, that.tools)
|
||||
&& Objects.equals(functionCallbacks, that.functionCallbacks)
|
||||
&& Objects.equals(functions, that.functions);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
|
||||
return Objects.hash(
|
||||
model,
|
||||
seed,
|
||||
maxTokens,
|
||||
topP,
|
||||
topK,
|
||||
repetitionPenalty,
|
||||
temperature,
|
||||
stop,
|
||||
stream,
|
||||
enableSearch,
|
||||
resultFormat,
|
||||
incrementalOutput,
|
||||
tools,
|
||||
functionCallbacks,
|
||||
functions
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
|
||||
final StringBuilder sb = new StringBuilder("TongYiChatOptions{");
|
||||
|
||||
sb.append(", model='").append(model).append('\'');
|
||||
sb.append(", seed=").append(seed);
|
||||
sb.append(", maxTokens=").append(maxTokens);
|
||||
sb.append(", topP=").append(topP);
|
||||
sb.append(", topK=").append(topK);
|
||||
sb.append(", repetitionPenalty=").append(repetitionPenalty);
|
||||
sb.append(", temperature=").append(temperature);
|
||||
sb.append(", stop=").append(stop);
|
||||
sb.append(", stream=").append(stream);
|
||||
sb.append(", enableSearch=").append(enableSearch);
|
||||
sb.append(", resultFormat='").append(resultFormat).append('\'');
|
||||
sb.append(", incrementalOutput=").append(incrementalOutput);
|
||||
sb.append(", tools=").append(tools);
|
||||
sb.append(", functionCallbacks=").append(functionCallbacks);
|
||||
sb.append(", functions=").append(functions);
|
||||
sb.append('}');
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
protected TongYiChatOptions options;
|
||||
|
||||
public Builder() {
|
||||
|
||||
this.options = new TongYiChatOptions();
|
||||
}
|
||||
|
||||
public Builder(TongYiChatOptions options) {
|
||||
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
public Builder withModel(String model) {
|
||||
this.options.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withMaxTokens(Integer maxTokens) {
|
||||
this.options.maxTokens = maxTokens;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withResultFormat(String rf) {
|
||||
this.options.resultFormat = rf;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withEnableSearch(Boolean enableSearch) {
|
||||
this.options.enableSearch = enableSearch;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
|
||||
this.options.functionCallbacks = functionCallbacks;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withFunctions(Set<String> functionNames) {
|
||||
Assert.notNull(functionNames, "Function names must not be null");
|
||||
this.options.functions = functionNames;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withFunction(String functionName) {
|
||||
Assert.hasText(functionName, "Function name must not be empty");
|
||||
this.options.functions.add(functionName);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withSeed(Integer seed) {
|
||||
this.options.seed = seed;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withStop(List<String> stop) {
|
||||
this.options.stop = stop;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withTemperature(Double temperature) {
|
||||
this.options.temperature = temperature;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withTopP(Double topP) {
|
||||
this.options.topP = topP;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withTopK(Integer topK) {
|
||||
this.options.topK = topK;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withRepetitionPenalty(Double repetitionPenalty) {
|
||||
this.options.repetitionPenalty = repetitionPenalty;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TongYiChatOptions build() {
|
||||
|
||||
return this.options;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.chat;
|
||||
|
||||
import com.alibaba.dashscope.aigc.generation.Generation;
|
||||
import com.alibaba.dashscope.aigc.generation.GenerationParam;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||
|
||||
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
@ConfigurationProperties(TongYiChatProperties.CONFIG_PREFIX)
|
||||
public class TongYiChatProperties {
|
||||
|
||||
/**
|
||||
* Spring Cloud Alibaba AI configuration prefix.
|
||||
*/
|
||||
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "chat";
|
||||
|
||||
/**
|
||||
* Default TongYi Chat model.
|
||||
*/
|
||||
public static final String DEFAULT_DEPLOYMENT_NAME = Generation.Models.QWEN_TURBO;
|
||||
|
||||
/**
|
||||
* Default temperature speed.
|
||||
*/
|
||||
private static final Double DEFAULT_TEMPERATURE = 0.8;
|
||||
|
||||
/**
|
||||
* Enable TongYiQWEN ai chat client.
|
||||
*/
|
||||
private boolean enabled = true;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
private TongYiChatOptions options = TongYiChatOptions.builder()
|
||||
.withModel(DEFAULT_DEPLOYMENT_NAME)
|
||||
.withTemperature(DEFAULT_TEMPERATURE)
|
||||
.withEnableSearch(true)
|
||||
.withResultFormat(GenerationParam.ResultFormat.MESSAGE)
|
||||
.build();
|
||||
|
||||
public TongYiChatOptions getOptions() {
|
||||
|
||||
return this.options;
|
||||
}
|
||||
|
||||
public void setOptions(TongYiChatOptions options) {
|
||||
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
public boolean isEnabled() {
|
||||
|
||||
return this.enabled;
|
||||
}
|
||||
|
||||
public void setEnabled(boolean enabled) {
|
||||
|
||||
this.enabled = enabled;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
/*
|
||||
* Copyright 2024-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.common.constants;
|
||||
|
||||
/**
 * String constants shared by the TongYi integration. Not instantiable.
 *
 * @author yuluo
 * @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
 */
public final class TongYiConstants {

    /** Prefix of the TongYi configuration property keys (ends with a dot). */
    public static final String SCA_AI_CONFIGURATION = "spring.cloud.ai.tongyi.";

    /** Prefix shared by the upper-case constant names defined here. */
    public static final String SCA_AI = "SPRING_CLOUD_ALIBABA_";

    /** Environment variable name of the TongYi API key: {@link #SCA_AI} + {@code "TONGYI_API_KEY"}. */
    public static final String SCA_AI_TONGYI_API_KEY = SCA_AI + "TONGYI_API_KEY";

    /** Constants holder only; instantiation is prevented. */
    private TongYiConstants() {
    }

}
|
||||
@@ -0,0 +1,38 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.common.exception;
|
||||
|
||||
/**
 * Unchecked exception raised by the TongYi model integration.
 *
 * @author yuluo
 * @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
 * @since 2023.0.1.0
 */
public class TongYiException extends RuntimeException {

    /**
     * Creates an exception that also records the underlying failure.
     *
     * @param message description of the failure
     * @param cause the exception that triggered this one
     */
    public TongYiException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * Creates an exception without a cause.
     *
     * @param message description of the failure
     */
    public TongYiException(String message) {
        super(message);
    }

}
|
||||
@@ -0,0 +1,39 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.common.exception;
|
||||
|
||||
/**
|
||||
* TongYi models images exception.
|
||||
*
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiImagesException extends TongYiException {
|
||||
|
||||
public TongYiImagesException(String message) {
|
||||
|
||||
super(message);
|
||||
}
|
||||
|
||||
public TongYiImagesException(String message, Throwable cause) {
|
||||
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.embedding;
|
||||
|
||||
import com.alibaba.dashscope.embeddings.TextEmbeddingParam;
|
||||
import org.springframework.ai.embedding.EmbeddingOptions;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author why_ohh
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:550588941@qq.com">why_ohh</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public final class TongYiEmbeddingOptions implements EmbeddingOptions {
|
||||
|
||||
private List<String> texts;
|
||||
|
||||
private TextEmbeddingParam.TextType textType;
|
||||
|
||||
public List<String> getTexts() {
|
||||
return texts;
|
||||
}
|
||||
|
||||
public void setTexts(List<String> texts) {
|
||||
this.texts = texts;
|
||||
}
|
||||
|
||||
public TextEmbeddingParam.TextType getTextType() {
|
||||
return textType;
|
||||
}
|
||||
|
||||
public void setTextType(TextEmbeddingParam.TextType textType) {
|
||||
this.textType = textType;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public final static class Builder {
|
||||
|
||||
private final TongYiEmbeddingOptions options;
|
||||
|
||||
private Builder() {
|
||||
this.options = new TongYiEmbeddingOptions();
|
||||
}
|
||||
|
||||
public Builder withtexts(List<String> texts) {
|
||||
|
||||
options.setTexts(texts);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withtextType(TextEmbeddingParam.TextType textType) {
|
||||
|
||||
options.setTextType(textType);
|
||||
return this;
|
||||
}
|
||||
|
||||
public TongYiEmbeddingOptions build() {
|
||||
|
||||
return options;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,175 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.embedding;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiException;
|
||||
import com.alibaba.cloud.ai.tongyi.metadata.TongYiTextEmbeddingResponseMetadata;
|
||||
import com.alibaba.dashscope.embeddings.TextEmbedding;
|
||||
import com.alibaba.dashscope.embeddings.TextEmbeddingParam;
|
||||
import com.alibaba.dashscope.embeddings.TextEmbeddingResult;
|
||||
import com.alibaba.dashscope.embeddings.TextEmbeddingResultItem;
|
||||
import com.alibaba.dashscope.exception.InputRequiredException;
|
||||
import com.alibaba.dashscope.exception.NoApiKeyException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.MetadataMode;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingModel;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
import org.springframework.ai.embedding.EmbeddingRequest;
|
||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* {@link TongYiTextEmbeddingModel} implementation for {@literal Alibaba DashScope}.
|
||||
*
|
||||
* @author why_ohh
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:550588941@qq.com">why_ohh</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiTextEmbeddingModel extends AbstractEmbeddingModel {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(TongYiTextEmbeddingModel.class);
|
||||
|
||||
/**
|
||||
* TongYi Text Embedding client.
|
||||
*/
|
||||
private final TextEmbedding textEmbedding;
|
||||
|
||||
/**
|
||||
* {@link MetadataMode}.
|
||||
*/
|
||||
private final MetadataMode metadataMode;
|
||||
|
||||
private final TongYiEmbeddingOptions defaultOptions;
|
||||
|
||||
public TongYiTextEmbeddingModel(TextEmbedding textEmbedding) {
|
||||
|
||||
this(textEmbedding, MetadataMode.EMBED);
|
||||
}
|
||||
|
||||
public TongYiTextEmbeddingModel(TextEmbedding textEmbedding, MetadataMode metadataMode) {
|
||||
|
||||
this(textEmbedding, metadataMode,
|
||||
TongYiEmbeddingOptions.builder()
|
||||
.withtextType(TextEmbeddingParam.TextType.DOCUMENT)
|
||||
.build()
|
||||
);
|
||||
}
|
||||
|
||||
public TongYiTextEmbeddingModel(
|
||||
TextEmbedding textEmbedding,
|
||||
MetadataMode metadataMode,
|
||||
TongYiEmbeddingOptions options
|
||||
) {
|
||||
Assert.notNull(textEmbedding, "textEmbedding must not be null");
|
||||
Assert.notNull(metadataMode, "Metadata mode must not be null");
|
||||
Assert.notNull(options, "TongYiEmbeddingOptions must not be null");
|
||||
|
||||
this.metadataMode = metadataMode;
|
||||
this.textEmbedding = textEmbedding;
|
||||
this.defaultOptions = options;
|
||||
}
|
||||
|
||||
public TongYiEmbeddingOptions getDefaultOptions() {
|
||||
|
||||
return this.defaultOptions;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> embed(Document document) {
|
||||
|
||||
return this.call(
|
||||
new EmbeddingRequest(
|
||||
List.of(document.getFormattedContent(this.metadataMode)),
|
||||
null)
|
||||
).getResults().stream()
|
||||
.map(Embedding::getOutput)
|
||||
.flatMap(List::stream)
|
||||
.toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingResponse call(EmbeddingRequest request) {
|
||||
|
||||
TextEmbeddingParam embeddingParams = toEmbeddingParams(request);
|
||||
logger.debug("Embedding request: {}", embeddingParams);
|
||||
TextEmbeddingResult resp;
|
||||
|
||||
try {
|
||||
resp = textEmbedding.call(embeddingParams);
|
||||
}
|
||||
catch (NoApiKeyException e) {
|
||||
throw new TongYiException(e.getMessage());
|
||||
}
|
||||
|
||||
return genEmbeddingResp(resp);
|
||||
}
|
||||
|
||||
private EmbeddingResponse genEmbeddingResp(TextEmbeddingResult result) {
|
||||
|
||||
return new EmbeddingResponse(
|
||||
genEmbeddingList(result.getOutput().getEmbeddings()),
|
||||
TongYiTextEmbeddingResponseMetadata.from(result.getUsage())
|
||||
);
|
||||
}
|
||||
|
||||
private List<Embedding> genEmbeddingList(List<TextEmbeddingResultItem> embeddings) {
|
||||
|
||||
return embeddings.stream()
|
||||
.map(embedding ->
|
||||
new Embedding(
|
||||
embedding.getEmbedding(),
|
||||
embedding.getTextIndex()
|
||||
))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* We recommend setting the model parameters by passing the embedding parameters through the code;
|
||||
* yml configuration compatibility is not considered here.
|
||||
* It is not recommended that users set parameters from yml,
|
||||
* as this reduces the flexibility of the configuration.
|
||||
* Because the model name keeps changing, strings are used here and constants are undefined:
|
||||
* Model list: <a href="https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-quick-start">Text Embedding Model List</a>
|
||||
* @param requestOptions Client params. {@link EmbeddingRequest}
|
||||
* @return {@link TextEmbeddingParam}
|
||||
*/
|
||||
private TextEmbeddingParam toEmbeddingParams(EmbeddingRequest requestOptions) {
|
||||
|
||||
TextEmbeddingParam tongYiEmbeddingParams = TextEmbeddingParam.builder()
|
||||
.texts(requestOptions.getInstructions())
|
||||
.textType(defaultOptions.getTextType() != null ? defaultOptions.getTextType() : TextEmbeddingParam.TextType.DOCUMENT)
|
||||
.model("text-embedding-v1")
|
||||
.build();
|
||||
|
||||
try {
|
||||
tongYiEmbeddingParams.validate();
|
||||
}
|
||||
catch (InputRequiredException e) {
|
||||
throw new TongYiException(e.getMessage());
|
||||
}
|
||||
|
||||
return tongYiEmbeddingParams;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.embedding;
|
||||
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
|
||||
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
|
||||
|
||||
/**
|
||||
* @author why_ohh
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:550588941@qq.com">why_ohh</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
@ConfigurationProperties(TongYiTextEmbeddingProperties.CONFIG_PREFIX)
|
||||
public class TongYiTextEmbeddingProperties {
|
||||
|
||||
/**
|
||||
* Prefix of TongYi Text Embedding properties.
|
||||
*/
|
||||
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "embedding";
|
||||
|
||||
private boolean enabled = true;
|
||||
|
||||
public boolean isEnabled() {
|
||||
|
||||
return this.enabled;
|
||||
}
|
||||
|
||||
public void setEnabled(boolean enabled) {
|
||||
|
||||
this.enabled = enabled;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,237 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.image;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiImagesException;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
|
||||
import com.alibaba.dashscope.exception.NoApiKeyException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.image.*;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.net.URL;
|
||||
import java.util.Base64;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.alibaba.cloud.ai.tongyi.metadata.TongYiImagesResponseMetadata.from;
|
||||
|
||||
/**
|
||||
* TongYiImagesClient is a class that implements the ImageClient interface. It provides a
|
||||
* client for calling the TongYi AI image generation API.
|
||||
*
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiImagesModel implements ImageModel {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(TongYiImagesModel.class);
|
||||
|
||||
/**
|
||||
* Gen Images API.
|
||||
*/
|
||||
private final ImageSynthesis imageSynthesis;
|
||||
|
||||
/**
|
||||
* TongYi Gen images properties.
|
||||
*/
|
||||
private TongYiImagesOptions defaultOptions;
|
||||
|
||||
/**
|
||||
* Adapt TongYi images api size properties.
|
||||
*/
|
||||
private final String sizeConnection = "*";
|
||||
|
||||
/**
|
||||
* Get default images options.
|
||||
*
|
||||
* @return Default TongYiImagesOptions.
|
||||
*/
|
||||
public TongYiImagesOptions getDefaultOptions() {
|
||||
|
||||
return this.defaultOptions;
|
||||
}
|
||||
|
||||
/**
|
||||
* TongYiImagesClient constructor.
|
||||
* @param imageSynthesis the image synthesis
|
||||
* {@link ImageSynthesis}
|
||||
*/
|
||||
public TongYiImagesModel(ImageSynthesis imageSynthesis) {
|
||||
|
||||
this(imageSynthesis, TongYiImagesOptions.
|
||||
builder()
|
||||
.withModel(ImageSynthesis.Models.WANX_V1)
|
||||
.withN(1)
|
||||
.build()
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* TongYiImagesClient constructor.
|
||||
* @param imageSynthesis {@link ImageSynthesis}
|
||||
* @param imagesOptions {@link TongYiImagesOptions}
|
||||
*/
|
||||
public TongYiImagesModel(ImageSynthesis imageSynthesis, TongYiImagesOptions imagesOptions) {
|
||||
|
||||
Assert.notNull(imageSynthesis, "ImageSynthesis must not be null");
|
||||
Assert.notNull(imagesOptions, "TongYiImagesOptions must not be null");
|
||||
|
||||
this.imageSynthesis = imageSynthesis;
|
||||
this.defaultOptions = imagesOptions;
|
||||
}
|
||||
|
||||
/**
|
||||
* Call the TongYi images service.
|
||||
* @param imagePrompt the image prompt.
|
||||
* @return the image response.
|
||||
* {@link ImageSynthesis#call(ImageSynthesisParam)}
|
||||
*/
|
||||
@Override
|
||||
public ImageResponse call(ImagePrompt imagePrompt) {
|
||||
|
||||
ImageSynthesisResult result;
|
||||
String prompt = imagePrompt.getInstructions().get(0).getText();
|
||||
var imgParams = ImageSynthesisParam.builder()
|
||||
.prompt("")
|
||||
.model(ImageSynthesis.Models.WANX_V1)
|
||||
.build();
|
||||
|
||||
if (this.defaultOptions != null) {
|
||||
|
||||
imgParams = merge(this.defaultOptions);
|
||||
}
|
||||
|
||||
if (imagePrompt.getOptions() != null) {
|
||||
|
||||
imgParams = merge(toTingYiImageOptions(imagePrompt.getOptions()));
|
||||
}
|
||||
imgParams.setPrompt(prompt);
|
||||
|
||||
try {
|
||||
result = imageSynthesis.call(imgParams);
|
||||
}
|
||||
catch (NoApiKeyException e) {
|
||||
|
||||
logger.error("TongYi models gen images failed: {}.", e.getMessage());
|
||||
throw new TongYiImagesException(e.getMessage());
|
||||
}
|
||||
|
||||
return convert(result);
|
||||
}
|
||||
|
||||
public ImageSynthesisParam merge(TongYiImagesOptions target) {
|
||||
|
||||
var builder = ImageSynthesisParam.builder();
|
||||
|
||||
builder.model(this.defaultOptions.getModel() != null ? this.defaultOptions.getModel() : target.getModel());
|
||||
builder.n(this.defaultOptions.getN() != null ? this.defaultOptions.getN() : target.getN());
|
||||
builder.size((this.defaultOptions.getHeight() != null && this.defaultOptions.getWidth() != null)
|
||||
? this.defaultOptions.getHeight() + "*" + this.defaultOptions.getWidth()
|
||||
: target.getHeight() + "*" + target.getWidth()
|
||||
);
|
||||
|
||||
// prompt is marked non-null but is null.
|
||||
builder.prompt("");
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
private ImageResponse convert(ImageSynthesisResult result) {
|
||||
|
||||
return new ImageResponse(
|
||||
result.getOutput().getResults().stream()
|
||||
.flatMap(value -> value.entrySet().stream())
|
||||
.map(entry -> {
|
||||
String key = entry.getKey();
|
||||
String value = entry.getValue();
|
||||
try {
|
||||
String base64Image = convertImageToBase64(value);
|
||||
Image image = new Image(value, base64Image);
|
||||
return new ImageGeneration(image);
|
||||
}
|
||||
catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
})
|
||||
.collect(Collectors.toList()),
|
||||
from(result)
|
||||
);
|
||||
}
|
||||
|
||||
public TongYiImagesOptions toTingYiImageOptions(ImageOptions runtimeImageOptions) {
|
||||
|
||||
var builder = TongYiImagesOptions.builder();
|
||||
|
||||
if (runtimeImageOptions != null) {
|
||||
if (runtimeImageOptions.getN() != null) {
|
||||
|
||||
builder.withN(runtimeImageOptions.getN());
|
||||
}
|
||||
if (runtimeImageOptions.getModel() != null) {
|
||||
|
||||
builder.withModel(runtimeImageOptions.getModel());
|
||||
}
|
||||
if (runtimeImageOptions.getHeight() != null) {
|
||||
|
||||
builder.withHeight(runtimeImageOptions.getHeight());
|
||||
}
|
||||
if (runtimeImageOptions.getWidth() != null) {
|
||||
|
||||
builder.withWidth(runtimeImageOptions.getWidth());
|
||||
}
|
||||
|
||||
// todo ImagesParams.
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert image to base64.
|
||||
* @param imageUrl the image url.
|
||||
* @return the base64 image.
|
||||
* @throws Exception the exception.
|
||||
*/
|
||||
public String convertImageToBase64(String imageUrl) throws Exception {
|
||||
|
||||
var url = new URL(imageUrl);
|
||||
var inputStream = url.openStream();
|
||||
var outputStream = new ByteArrayOutputStream();
|
||||
var buffer = new byte[4096];
|
||||
int bytesRead;
|
||||
|
||||
while ((bytesRead = inputStream.read(buffer)) != -1) {
|
||||
outputStream.write(buffer, 0, bytesRead);
|
||||
}
|
||||
|
||||
var imageBytes = outputStream.toByteArray();
|
||||
|
||||
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
|
||||
|
||||
inputStream.close();
|
||||
outputStream.close();
|
||||
|
||||
return base64Image;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,187 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.image;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiImagesException;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||
import org.springframework.ai.image.ImageOptions;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* TongYi Image API options.
|
||||
*
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiImagesOptions implements ImageOptions {
|
||||
|
||||
/**
|
||||
* Specify the model name, currently only wanx-v1 is supported.
|
||||
*/
|
||||
private String model = ImageSynthesis.Models.WANX_V1;
|
||||
|
||||
/**
|
||||
* Gen images number.
|
||||
*/
|
||||
private Integer n;
|
||||
|
||||
/**
|
||||
* The width of the generated images.
|
||||
*/
|
||||
private Integer width = 1024;
|
||||
|
||||
/**
|
||||
* The height of the generated images.
|
||||
*/
|
||||
private Integer height = 1024;
|
||||
|
||||
@Override
|
||||
public Integer getN() {
|
||||
|
||||
return this.n;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModel() {
|
||||
|
||||
return this.model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getWidth() {
|
||||
|
||||
return this.width;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getHeight() {
|
||||
|
||||
return this.height;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getResponseFormat() {
|
||||
|
||||
throw new TongYiImagesException("unimplemented!");
|
||||
}
|
||||
|
||||
public void setModel(String model) {
|
||||
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
public void setN(Integer n) {
|
||||
|
||||
this.n = n;
|
||||
}
|
||||
|
||||
public void setWidth(Integer width) {
|
||||
|
||||
this.width = width;
|
||||
}
|
||||
|
||||
public void setHeight(Integer height) {
|
||||
|
||||
this.height = height;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
TongYiImagesOptions that = (TongYiImagesOptions) o;
|
||||
|
||||
return Objects.equals(model, that.model)
|
||||
&& Objects.equals(n, that.n)
|
||||
&& Objects.equals(width, that.width)
|
||||
&& Objects.equals(height, that.height);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
|
||||
return Objects.hash(model, n, width, height);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
|
||||
final StringBuilder sb = new StringBuilder("TongYiImagesOptions{");
|
||||
|
||||
sb.append("model='").append(model).append('\'');
|
||||
sb.append(", n=").append(n);
|
||||
sb.append(", width=").append(width);
|
||||
sb.append(", height=").append(height);
|
||||
sb.append('}');
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public final static class Builder {
|
||||
|
||||
private final TongYiImagesOptions options;
|
||||
|
||||
private Builder() {
|
||||
this.options = new TongYiImagesOptions();
|
||||
}
|
||||
|
||||
public Builder withN(Integer n) {
|
||||
|
||||
options.setN(n);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withModel(String model) {
|
||||
|
||||
options.setModel(model);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withWidth(Integer width) {
|
||||
|
||||
options.setWidth(width);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withHeight(Integer height) {
|
||||
|
||||
options.setHeight(height);
|
||||
return this;
|
||||
}
|
||||
|
||||
public TongYiImagesOptions build() {
|
||||
|
||||
return options;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.image;
|
||||
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||
|
||||
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
|
||||
|
||||
/**
|
||||
* TongYi Image API properties.
|
||||
*
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
@ConfigurationProperties(TongYiImagesProperties.CONFIG_PREFIX)
|
||||
public class TongYiImagesProperties {
|
||||
|
||||
/**
|
||||
* Spring Cloud Alibaba AI configuration prefix.
|
||||
*/
|
||||
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "images";
|
||||
|
||||
/**
|
||||
* Default TongYi Chat model.
|
||||
*/
|
||||
public static final String DEFAULT_IMAGES_MODEL_NAME = ImageSynthesis.Models.WANX_V1;
|
||||
|
||||
/**
|
||||
* Enable TongYiQWEN ai images client.
|
||||
*/
|
||||
private boolean enabled = true;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
private TongYiImagesOptions options = TongYiImagesOptions.builder()
|
||||
.withModel(DEFAULT_IMAGES_MODEL_NAME)
|
||||
.withN(1)
|
||||
.build();
|
||||
|
||||
public TongYiImagesOptions getOptions() {
|
||||
|
||||
return this.options;
|
||||
}
|
||||
|
||||
public void setOptions(TongYiImagesOptions options) {
|
||||
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
public boolean isEnabled() {
|
||||
|
||||
return this.enabled;
|
||||
}
|
||||
|
||||
public void setEnabled(boolean enabled) {
|
||||
|
||||
this.enabled = enabled;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.metadata;
|
||||
|
||||
import com.alibaba.dashscope.aigc.generation.GenerationResult;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.PromptMetadata;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* {@link ChatResponseMetadata} implementation for {@literal Alibaba DashScope}.
|
||||
*
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
|
||||
|
||||
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }";
|
||||
|
||||
@SuppressWarnings("all")
|
||||
public static TongYiAiChatResponseMetadata from(GenerationResult chatCompletions,
|
||||
PromptMetadata promptFilterMetadata) {
|
||||
|
||||
Assert.notNull(chatCompletions, "Alibaba ai ChatCompletions must not be null");
|
||||
String id = chatCompletions.getRequestId();
|
||||
TongYiAiUsage usage = TongYiAiUsage.from(chatCompletions);
|
||||
|
||||
return new TongYiAiChatResponseMetadata(
|
||||
id,
|
||||
usage,
|
||||
promptFilterMetadata
|
||||
);
|
||||
}
|
||||
|
||||
private final String id;
|
||||
|
||||
private final Usage usage;
|
||||
|
||||
private final PromptMetadata promptMetadata;
|
||||
|
||||
protected TongYiAiChatResponseMetadata(String id, TongYiAiUsage usage, PromptMetadata promptMetadata) {
|
||||
|
||||
this.id = id;
|
||||
this.usage = usage;
|
||||
this.promptMetadata = promptMetadata;
|
||||
}
|
||||
|
||||
public String getId() {
|
||||
return this.id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Usage getUsage() {
|
||||
return this.usage;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptMetadata getPromptMetadata() {
|
||||
return this.promptMetadata;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
|
||||
return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getUsage(), getRateLimit());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.metadata;
|
||||
|
||||
import com.alibaba.dashscope.aigc.generation.GenerationResult;
|
||||
import com.alibaba.dashscope.aigc.generation.GenerationUsage;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* {@link Usage} implementation for {@literal Alibaba DashScope}.
|
||||
*
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiAiUsage implements Usage {
|
||||
|
||||
private final GenerationUsage usage;
|
||||
|
||||
public TongYiAiUsage(GenerationUsage usage) {
|
||||
|
||||
Assert.notNull(usage, "GenerationUsage must not be null");
|
||||
this.usage = usage;
|
||||
}
|
||||
|
||||
public static TongYiAiUsage from(GenerationResult chatCompletions) {
|
||||
|
||||
Assert.notNull(chatCompletions, "ChatCompletions must not be null");
|
||||
return from(chatCompletions.getUsage());
|
||||
}
|
||||
|
||||
public static TongYiAiUsage from(GenerationUsage usage) {
|
||||
|
||||
return new TongYiAiUsage(usage);
|
||||
}
|
||||
|
||||
protected GenerationUsage getUsage() {
|
||||
|
||||
return this.usage;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getPromptTokens() {
|
||||
|
||||
throw new UnsupportedOperationException("Unimplemented method 'getPromptTokens'");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getGenerationTokens() {
|
||||
|
||||
return this.getUsage().getOutputTokens().longValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getTotalTokens() {
|
||||
|
||||
return this.getUsage().getTotalTokens().longValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
|
||||
return this.getUsage().toString();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.metadata;
|
||||
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisTaskMetrics;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisUsage;
|
||||
import org.springframework.ai.image.ImageResponseMetadata;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiImagesResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
|
||||
|
||||
private final Long created;
|
||||
|
||||
private String taskId;
|
||||
|
||||
private ImageSynthesisTaskMetrics metrics;
|
||||
|
||||
private ImageSynthesisUsage usage;
|
||||
|
||||
public static TongYiImagesResponseMetadata from(ImageSynthesisResult synthesisResult) {
|
||||
|
||||
Assert.notNull(synthesisResult, "TongYiAiImageResponse must not be null");
|
||||
|
||||
return new TongYiImagesResponseMetadata(
|
||||
System.currentTimeMillis(),
|
||||
synthesisResult.getOutput().getTaskMetrics(),
|
||||
synthesisResult.getOutput().getTaskId(),
|
||||
synthesisResult.getUsage()
|
||||
);
|
||||
}
|
||||
|
||||
protected TongYiImagesResponseMetadata(
|
||||
Long created,
|
||||
ImageSynthesisTaskMetrics metrics,
|
||||
String taskId,
|
||||
ImageSynthesisUsage usage
|
||||
) {
|
||||
|
||||
this.taskId = taskId;
|
||||
this.metrics = metrics;
|
||||
this.created = created;
|
||||
this.usage = usage;
|
||||
}
|
||||
|
||||
public ImageSynthesisUsage getUsage() {
|
||||
return usage;
|
||||
}
|
||||
|
||||
public void setUsage(ImageSynthesisUsage usage) {
|
||||
this.usage = usage;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getCreated() {
|
||||
return created;
|
||||
}
|
||||
|
||||
public String getTaskId() {
|
||||
return taskId;
|
||||
}
|
||||
|
||||
public void setTaskId(String taskId) {
|
||||
this.taskId = taskId;
|
||||
}
|
||||
|
||||
public ImageSynthesisTaskMetrics getMetrics() {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
void setMetrics(ImageSynthesisTaskMetrics metrics) {
|
||||
this.metrics = metrics;
|
||||
}
|
||||
|
||||
|
||||
public Long created() {
|
||||
return this.created;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "TongYiImagesResponseMetadata {" + "created=" + created + '}';
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
TongYiImagesResponseMetadata that = (TongYiImagesResponseMetadata) o;
|
||||
|
||||
return Objects.equals(created, that.created)
|
||||
&& Objects.equals(taskId, that.taskId)
|
||||
&& Objects.equals(metrics, that.metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(created, taskId, metrics);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.metadata;
|
||||
|
||||
import com.alibaba.dashscope.embeddings.TextEmbeddingUsage;
|
||||
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
|
||||
|
||||
/**
|
||||
* @author why_ohh
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:550588941@qq.com">why_ohh</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiTextEmbeddingResponseMetadata extends EmbeddingResponseMetadata {
|
||||
|
||||
private Integer totalTokens;
|
||||
|
||||
protected TongYiTextEmbeddingResponseMetadata(Integer totalTokens) {
|
||||
|
||||
this.totalTokens = totalTokens;
|
||||
}
|
||||
|
||||
public static TongYiTextEmbeddingResponseMetadata from(TextEmbeddingUsage usage) {
|
||||
|
||||
return new TongYiTextEmbeddingResponseMetadata(usage.getTotalTokens());
|
||||
}
|
||||
|
||||
public Integer getTotalTokens() {
|
||||
|
||||
return totalTokens;
|
||||
}
|
||||
|
||||
public void setTotalTokens(Integer totalTokens) {
|
||||
|
||||
this.totalTokens = totalTokens;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.metadata.audio;
|
||||
|
||||
import com.alibaba.dashscope.audio.tts.SpeechSynthesisResult;
|
||||
import com.alibaba.dashscope.audio.tts.SpeechSynthesisUsage;
|
||||
import com.alibaba.dashscope.audio.tts.timestamp.Sentence;
|
||||
import org.springframework.ai.chat.metadata.EmptyRateLimit;
|
||||
import org.springframework.ai.chat.metadata.RateLimit;
|
||||
import org.springframework.ai.model.ResponseMetadata;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiAudioSpeechResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
|
||||
|
||||
private SpeechSynthesisUsage usage;
|
||||
|
||||
private String requestId;
|
||||
|
||||
private Sentence time;
|
||||
|
||||
protected static final String AI_METADATA_STRING = "{ @type: %1$s, requestsLimit: %2$s }";
|
||||
|
||||
/**
|
||||
* NULL objects.
|
||||
*/
|
||||
public static final TongYiAudioSpeechResponseMetadata NULL = new TongYiAudioSpeechResponseMetadata() {
|
||||
};
|
||||
|
||||
public static TongYiAudioSpeechResponseMetadata from(SpeechSynthesisResult result) {
|
||||
|
||||
Assert.notNull(result, "TongYi AI speech must not be null");
|
||||
TongYiAudioSpeechResponseMetadata speechResponseMetadata = new TongYiAudioSpeechResponseMetadata();
|
||||
|
||||
|
||||
|
||||
return speechResponseMetadata;
|
||||
}
|
||||
|
||||
public static TongYiAudioSpeechResponseMetadata from(String result) {
|
||||
|
||||
Assert.notNull(result, "TongYi AI speech must not be null");
|
||||
TongYiAudioSpeechResponseMetadata speechResponseMetadata = new TongYiAudioSpeechResponseMetadata();
|
||||
|
||||
return speechResponseMetadata;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private RateLimit rateLimit;
|
||||
|
||||
public TongYiAudioSpeechResponseMetadata() {
|
||||
|
||||
this(null);
|
||||
}
|
||||
|
||||
public TongYiAudioSpeechResponseMetadata(@Nullable RateLimit rateLimit) {
|
||||
|
||||
this.rateLimit = rateLimit;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public RateLimit getRateLimit() {
|
||||
|
||||
RateLimit rateLimit = this.rateLimit;
|
||||
return rateLimit != null ? rateLimit : new EmptyRateLimit();
|
||||
}
|
||||
|
||||
public TongYiAudioSpeechResponseMetadata withRateLimit(RateLimit rateLimit) {
|
||||
|
||||
this.rateLimit = rateLimit;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TongYiAudioSpeechResponseMetadata withUsage(SpeechSynthesisUsage usage) {
|
||||
|
||||
this.usage = usage;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TongYiAudioSpeechResponseMetadata withRequestId(String id) {
|
||||
|
||||
this.requestId = id;
|
||||
return this;
|
||||
}
|
||||
|
||||
public TongYiAudioSpeechResponseMetadata withSentence(Sentence sentence) {
|
||||
|
||||
this.time = sentence;
|
||||
return this;
|
||||
}
|
||||
|
||||
public SpeechSynthesisUsage getUsage() {
|
||||
return usage;
|
||||
}
|
||||
|
||||
public String getRequestId() {
|
||||
return requestId;
|
||||
}
|
||||
|
||||
public Sentence getTime() {
|
||||
return time;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return AI_METADATA_STRING.formatted(getClass().getName(), getRateLimit());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.metadata.audio;
|
||||
|
||||
import org.springframework.ai.model.ResultMetadata;
|
||||
|
||||
/**
|
||||
* @author xYLiu
|
||||
* @author yuluo
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public interface TongYiAudioTranscriptionMetadata extends ResultMetadata {
|
||||
|
||||
/**
|
||||
* A constant instance of {@link TongYiAudioTranscriptionMetadata} that represents a null or empty metadata.
|
||||
*/
|
||||
TongYiAudioTranscriptionMetadata NULL = TongYiAudioTranscriptionMetadata.create();
|
||||
|
||||
/**
|
||||
* Factory method for creating a new instance of {@link TongYiAudioTranscriptionMetadata}.
|
||||
* @return a new instance of {@link TongYiAudioTranscriptionMetadata}
|
||||
*/
|
||||
static TongYiAudioTranscriptionMetadata create() {
|
||||
return new TongYiAudioTranscriptionMetadata() {
|
||||
};
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.alibaba.cloud.ai.tongyi.metadata.audio;
|
||||
|
||||
import com.alibaba.dashscope.audio.asr.transcription.TranscriptionResult;
|
||||
import com.google.gson.JsonObject;
|
||||
import org.springframework.ai.chat.metadata.EmptyRateLimit;
|
||||
import org.springframework.ai.chat.metadata.RateLimit;
|
||||
import org.springframework.ai.model.ResponseMetadata;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import java.util.HashMap;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* @author yuluo
|
||||
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
|
||||
* @since 2023.0.1.0
|
||||
*/
|
||||
|
||||
public class TongYiAudioTranscriptionResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
|
||||
|
||||
@Nullable
|
||||
private RateLimit rateLimit;
|
||||
|
||||
private JsonObject usage;
|
||||
|
||||
protected static final String AI_METADATA_STRING = "{ @type: %1$s, rateLimit: %4$s }";
|
||||
|
||||
/**
|
||||
* NULL objects.
|
||||
*/
|
||||
public static final TongYiAudioTranscriptionResponseMetadata NULL = new TongYiAudioTranscriptionResponseMetadata() {
|
||||
};
|
||||
|
||||
protected TongYiAudioTranscriptionResponseMetadata() {
|
||||
|
||||
this(null, new JsonObject());
|
||||
}
|
||||
|
||||
protected TongYiAudioTranscriptionResponseMetadata(JsonObject usage) {
|
||||
|
||||
this(null, usage);
|
||||
}
|
||||
|
||||
protected TongYiAudioTranscriptionResponseMetadata(@Nullable RateLimit rateLimit, JsonObject usage) {
|
||||
|
||||
this.rateLimit = rateLimit;
|
||||
this.usage = usage;
|
||||
}
|
||||
|
||||
public static TongYiAudioTranscriptionResponseMetadata from(TranscriptionResult result) {
|
||||
|
||||
Assert.notNull(result, "TongYi Transcription must not be null");
|
||||
return new TongYiAudioTranscriptionResponseMetadata(result.getUsage());
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public RateLimit getRateLimit() {
|
||||
|
||||
return this.rateLimit != null ? this.rateLimit : new EmptyRateLimit();
|
||||
}
|
||||
|
||||
public void setRateLimit(@Nullable RateLimit rateLimit) {
|
||||
this.rateLimit = rateLimit;
|
||||
}
|
||||
|
||||
public JsonObject getUsage() {
|
||||
return usage;
|
||||
}
|
||||
|
||||
public void setUsage(JsonObject usage) {
|
||||
this.usage = usage;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
|
||||
return AI_METADATA_STRING.formatted(getClass().getName(), getRateLimit());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration
|
||||
@@ -0,0 +1,53 @@
|
||||
package cn.iocoder.yudao.framework.ai.chat;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link DeepSeekChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class DeepSeekChatModelTests {
|
||||
|
||||
private final DeepSeekChatModel chatModel = new DeepSeekChatModel("sk-e94db327cc7d457d99a8de8810fc6b12");
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(System.out::println).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package cn.iocoder.yudao.framework.ai.chat;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.ollama.OllamaChatModel;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.ollama.api.OllamaModel;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link OllamaChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class LlamaChatModelTests {
|
||||
|
||||
private final OllamaApi ollamaApi = new OllamaApi(
|
||||
"http://127.0.0.1:11434");
|
||||
private final OllamaChatModel chatModel = new OllamaChatModel(ollamaApi,
|
||||
OllamaOptions.create().withModel(OllamaModel.LLAMA3.getModelName()));
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package cn.iocoder.yudao.framework.ai.chat;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link XingHuoChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class OpenAIChatModelTests {
|
||||
|
||||
private final OpenAiApi openAiApi = new OpenAiApi(
|
||||
"https://api.holdai.top",
|
||||
"sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf");
|
||||
private final OpenAiChatModel chatModel = new OpenAiChatModel(openAiApi,
|
||||
OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O).build());
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package cn.iocoder.yudao.framework.ai.chat;
|
||||
|
||||
import cn.hutool.core.util.ReflectUtil;
|
||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
|
||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
||||
import com.alibaba.dashscope.aigc.generation.Generation;
|
||||
import com.alibaba.dashscope.common.MessageManager;
|
||||
import com.alibaba.dashscope.utils.Constants;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link TongYiChatModel} 集成测试类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public class TongYiChatModelTests {
|
||||
|
||||
private final Generation generation = new Generation();
|
||||
private final TongYiChatModel chatModel = new TongYiChatModel(generation,
|
||||
TongYiChatOptions.builder().withModel("qwen1.5-72b-chat").build());
|
||||
|
||||
static {
|
||||
Constants.apiKey = "sk-Zsd81gZYg7";
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
public void before() {
|
||||
// 防止 TongYiChatModel 调用空指针
|
||||
ReflectUtil.setFieldValue(chatModel, "msgManager", new MessageManager());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package cn.iocoder.yudao.framework.ai.chat;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link XingHuoChatModel} 集成测试
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public class XingHuoChatModelTests {
|
||||
|
||||
private final XingHuoChatModel chatModel = new XingHuoChatModel(
|
||||
"cb6415c19d6162cda07b47316fcb0416",
|
||||
"Y2JiYTIxZjA3MDMxMjNjZjQzYzVmNzdh");
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(System.out::println).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package cn.iocoder.yudao.framework.ai.chat;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.qianfan.QianFanChatModel;
|
||||
import org.springframework.ai.qianfan.QianFanChatOptions;
|
||||
import org.springframework.ai.qianfan.api.QianFanApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link QianFanChatModel} 的集成测试
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public class YiYanChatModelTests {
|
||||
|
||||
private final QianFanApi qianFanApi = new QianFanApi(
|
||||
"qS8k8dYr2nXunagK4SSU8Xjj",
|
||||
"pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
|
||||
private final QianFanChatModel chatModel = new QianFanChatModel(qianFanApi,
|
||||
QianFanChatOptions.builder().withModel(QianFanApi.ChatModel.ERNIE_Tiny_8K.getValue()).build()
|
||||
);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
// TODO @芋艿:文心一言,只要带上 system message 就报错,已经各种测试,很莫名!
|
||||
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
// TODO @芋艿:文心一言,只要带上 system message 就报错,已经各种测试,很莫名!
|
||||
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(System.out::println).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package cn.iocoder.yudao.framework.ai.chat;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
||||
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link ZhiPuAiChatModel} 的集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class ZhiPuAiChatModelTests {
|
||||
|
||||
private final ZhiPuAiApi zhiPuAiApi = new ZhiPuAiApi("32f84543e54eee31f8d56b2bd6020573.3vh9idLJZ2ZhxDEs");
|
||||
private final ZhiPuAiChatModel chatModel = new ZhiPuAiChatModel(zhiPuAiApi,
|
||||
ZhiPuAiChatOptions.builder().withModel(ZhiPuAiApi.ChatModel.GLM_4.getModelName()).build());
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(response -> {
|
||||
// System.out.println(response);
|
||||
System.out.println(response.getResult().getOutput());
|
||||
}).then().block();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package cn.iocoder.yudao.framework.ai.image;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link MidjourneyApi} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class MidjourneyApiTests {
|
||||
|
||||
private final MidjourneyApi midjourneyApi = new MidjourneyApi(
|
||||
"https://api.holdai.top/mj",
|
||||
"sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf",
|
||||
null);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testImagine() {
|
||||
// 准备参数
|
||||
MidjourneyApi.ImagineRequest request = new MidjourneyApi.ImagineRequest(null,
|
||||
"生成一个小猫,可爱的", null,
|
||||
MidjourneyApi.ImagineRequest.buildState(512, 512, "6.0", MidjourneyApi.ModelEnum.MIDJOURNEY.getModel()));
|
||||
|
||||
// 方法调用
|
||||
MidjourneyApi.SubmitResponse response = midjourneyApi.imagine(request);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testAction() {
|
||||
// 准备参数
|
||||
MidjourneyApi.ActionRequest request = new MidjourneyApi.ActionRequest("1720277033455953",
|
||||
"MJ::JOB::upsample::1::ee267661-ee52-4ced-a530-0343ba95af3b", null);
|
||||
|
||||
// 方法调用
|
||||
MidjourneyApi.SubmitResponse response = midjourneyApi.action(request);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testGetTaskList() {
|
||||
// 准备参数。该参数可以通过 MidjourneyApi.SubmitResponse 的 result 获取
|
||||
// String taskId = "1720277033455953";
|
||||
String taskId = "1720277214045971";
|
||||
|
||||
// 方法调用
|
||||
List<MidjourneyApi.Notify> taskList = midjourneyApi.getTaskList(Collections.singletonList(taskId));
|
||||
// 打印结果
|
||||
System.out.println(taskList);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package cn.iocoder.yudao.framework.ai.image;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImageOptions;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.openai.OpenAiImageModel;
|
||||
import org.springframework.ai.openai.OpenAiImageOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiImageApi;
|
||||
import org.springframework.web.client.RestClient;
|
||||
|
||||
/**
|
||||
* {@link OpenAiImageModel} 集成测试类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public class OpenAiImageModelTests {
|
||||
|
||||
private final OpenAiImageApi imageApi = new OpenAiImageApi(
|
||||
"https://api.holdai.top",
|
||||
"sk-dZEPiVaNcT3FHhef51996bAa0bC74806BeAb620dA5Da10Bf",
|
||||
RestClient.builder());
|
||||
private final OpenAiImageModel imageModel = new OpenAiImageModel(imageApi);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
ImageOptions options = OpenAiImageOptions.builder()
|
||||
.withModel(OpenAiImageApi.ImageModel.DALL_E_2.getValue()) // 这个模型比较便宜
|
||||
.withHeight(256).withWidth(256)
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("中国长城!", options);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package cn.iocoder.yudao.framework.ai.image;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.qianfan.QianFanImageModel;
|
||||
import org.springframework.ai.qianfan.QianFanImageOptions;
|
||||
import org.springframework.ai.qianfan.api.QianFanImageApi;
|
||||
|
||||
import static cn.iocoder.yudao.framework.ai.image.StabilityAiImageModelTests.viewImage;
|
||||
|
||||
/**
|
||||
* {@link QianFanImageModel} 集成测试类
|
||||
*/
|
||||
public class QianFanImageTests {
|
||||
|
||||
private final QianFanImageApi imageApi = new QianFanImageApi(
|
||||
"qS8k8dYr2nXunagK4SSU8Xjj", "pHGbx51ql2f0hOyabQvSZezahVC3hh3e");
|
||||
private final QianFanImageModel imageModel = new QianFanImageModel(imageApi);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
// 只支持 1024x1024、768x768、768x1024、1024x768、576x1024、1024x576
|
||||
QianFanImageOptions imageOptions = QianFanImageOptions.builder()
|
||||
.withModel(QianFanImageApi.ImageModel.Stable_Diffusion_XL.getValue())
|
||||
.withWidth(1024).withHeight(1024)
|
||||
.withN(1)
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("good", imageOptions);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
String b64Json = response.getResult().getOutput().getB64Json();
|
||||
System.out.println(response);
|
||||
viewImage(b64Json);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package cn.iocoder.yudao.framework.ai.image;
|
||||
|
||||
import cn.hutool.core.codec.Base64;
|
||||
import cn.hutool.core.thread.ThreadUtil;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImageOptions;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.openai.OpenAiImageOptions;
|
||||
import org.springframework.ai.stabilityai.StabilityAiImageModel;
|
||||
import org.springframework.ai.stabilityai.api.StabilityAiApi;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.awt.*;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* {@link StabilityAiImageModel} 集成测试类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public class StabilityAiImageModelTests {
|
||||
|
||||
private final StabilityAiApi imageApi = new StabilityAiApi(
|
||||
"sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx");
|
||||
private final StabilityAiImageModel imageModel = new StabilityAiImageModel(imageApi);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
ImageOptions options = OpenAiImageOptions.builder()
|
||||
.withModel("stable-diffusion-v1-6")
|
||||
.withHeight(256).withWidth(256)
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("great wall", options);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
String b64Json = response.getResult().getOutput().getB64Json();
|
||||
System.out.println(response);
|
||||
viewImage(b64Json);
|
||||
}
|
||||
|
||||
public static void viewImage(String b64Json) {
|
||||
// 创建一个 JFrame
|
||||
JFrame frame = new JFrame("Byte Image Display");
|
||||
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
|
||||
frame.setSize(800, 600);
|
||||
|
||||
// 创建一个 JLabel 来显示图片
|
||||
byte[] imageBytes = Base64.decode(b64Json);
|
||||
JLabel label = new JLabel(new ImageIcon(imageBytes));
|
||||
|
||||
// 将 JLabel 添加到 JFrame
|
||||
frame.getContentPane().add(label, BorderLayout.CENTER);
|
||||
|
||||
// 显示 JFrame
|
||||
frame.setVisible(true);
|
||||
ThreadUtil.sleep(1, TimeUnit.HOURS);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package cn.iocoder.yudao.framework.ai.image;
|
||||
|
||||
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||
import com.alibaba.dashscope.utils.Constants;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImageOptions;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.openai.OpenAiImageOptions;
|
||||
|
||||
/**
|
||||
* {@link com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel} 集成测试类
|
||||
*
|
||||
* @author fansili
|
||||
*/
|
||||
public class TongYiImagesModelTest {
|
||||
|
||||
private final ImageSynthesis imageApi = new ImageSynthesis();
|
||||
private final TongYiImagesModel imageModel = new TongYiImagesModel(imageApi);
|
||||
|
||||
static {
|
||||
Constants.apiKey = "sk-Zsd81gZYg7";
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void imageCallTest() {
|
||||
// 准备参数
|
||||
ImageOptions options = OpenAiImageOptions.builder()
|
||||
.withModel(ImageSynthesis.Models.WANX_V1)
|
||||
.withHeight(256).withWidth(256)
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("中国长城!", options);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package cn.iocoder.yudao.framework.ai.image;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiImageOptions;
|
||||
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
|
||||
|
||||
/**
|
||||
* {@link ZhiPuAiImageModel} 集成测试
|
||||
*/
|
||||
public class ZhiPuAiImageModelTests {
|
||||
|
||||
private final ZhiPuAiImageApi imageApi = new ZhiPuAiImageApi(
|
||||
"78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy");
|
||||
private final ZhiPuAiImageModel imageModel = new ZhiPuAiImageModel(imageApi);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
ZhiPuAiImageOptions imageOptions = ZhiPuAiImageOptions.builder()
|
||||
.withModel(ZhiPuAiImageApi.ImageModel.CogView_3.getValue())
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package cn.iocoder.yudao.framework.ai.music;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link SunoApi} 集成测试
|
||||
*
|
||||
* @author xiaoxin
|
||||
*/
|
||||
public class SunoApiTests {
|
||||
|
||||
private final SunoApi sunoApi = new SunoApi("https://suno-3tah0ycyt-status2xxs-projects.vercel.app");
|
||||
// private final SunoApi sunoApi = new SunoApi("http://127.0.0.1:3001");
|
||||
|
||||
@Test // 描述模式
|
||||
@Disabled
|
||||
public void testGenerate() {
|
||||
// 准备参数
|
||||
SunoApi.MusicGenerateRequest generateRequest = new SunoApi.MusicGenerateRequest(
|
||||
"happy music",
|
||||
"chirp-v3-5",
|
||||
false);
|
||||
|
||||
// 调用方法
|
||||
List<SunoApi.MusicData> musicList = sunoApi.generate(generateRequest);
|
||||
// 打印结果
|
||||
System.out.println(musicList);
|
||||
}
|
||||
|
||||
@Test // 歌词模式
|
||||
@Disabled
|
||||
public void testCustomGenerate() {
|
||||
// 准备参数
|
||||
SunoApi.MusicGenerateRequest generateRequest = new SunoApi.MusicGenerateRequest(
|
||||
"创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。",
|
||||
"Happy",
|
||||
"Happy Song",
|
||||
"chirp-v3.5",
|
||||
false,
|
||||
false);
|
||||
|
||||
// 调用方法
|
||||
List<SunoApi.MusicData> musicList = sunoApi.customGenerate(generateRequest);
|
||||
// 打印结果
|
||||
System.out.println(musicList);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testGenerateLyrics() {
|
||||
// 调用方法
|
||||
SunoApi.LyricsData lyricsData = sunoApi.generateLyrics("A soothing lullaby");
|
||||
// 打印结果
|
||||
System.out.println(lyricsData);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testGetMusicList() {
|
||||
// 准备参数
|
||||
// String id = "d460ddda-7c87-4f34-b751-419b08a590ca";
|
||||
String id = "584729e5-0fe9-4157-86da-1b4803ff42bf";
|
||||
|
||||
// 调用方法
|
||||
List<SunoApi.MusicData> musicList = sunoApi.getMusicList(List.of(id));
|
||||
// 打印结果
|
||||
System.out.println(musicList);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testGetLimitUsage() {
|
||||
// 调用方法
|
||||
SunoApi.LimitUsageData limitUsageData = sunoApi.getLimitUsage();
|
||||
// 打印结果
|
||||
System.out.println(limitUsageData);
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user