文章目录

参考先看看效果文心千帆创建应用思路步骤与代码

如题,第一次用websocket,做了个这玩意,只做了上下文的聊天,没做流式。 中间还有个低级报错但卡了好久,具体可以看【错误记录】websocket连接失败,但后端毫无反应,还有【错误记录】ruoyi-vue@Autowired注入自定义mapper时为null解决 ,感兴趣可前往观看。 实际上我后端用的是ruoyi-vue,前端用的ruoyi-app,但不重要。因为功能就是基于websocket和文心一言千帆大模型的接口,完全可以独立出来。 每个新建的账号会送一张20元的代金券,期限一个月内。而聊天服务接口单价约1分/千token,总之用来练手肯定够用了。

参考

文档中心-ERNIE-Bot-turbo 百度文心一言接入教程 若依插件-集成websocket实现简单通信

先看看效果

大致这样。

2023.10.13更新:昨天和朋友聊了一下,发现他的想法和我的不同——根本不用实体类去保存解析复杂的json,直接保存消息内容。有一说一,在这个小demo这里,确实可以更快更简单的实现,因为这个demo最耗时的就是看又臭又长的参数,然后写请求体和返回值的实体类,至少请求体实体类是可以不写的。

下面进入正题。

文心千帆创建应用

文心一言,大概是这里,先创建个账号,进控制台创建一个应用(有一个apikey和secretkey,有用),开通一个聊天服务(我开通的是ErnieBot-turbo),就可以了。具体有点忘了,大家可以参考其他博客。其次官方有给一些参考,API调用指南、在线测试平台,第二个链接可以对自己开通的聊天服务进行测试。其中也有一个分类是“技术文档”和“示例代码”,技术文档里边有普通/流式的请求/响应的参数和示例(如果比较小不容易看,文档中心-ERNIE-Bot-turbo也有),示例代码就是请求的各个语言的示例代码。

思路

有三个角色,大模型 ←→ 后端 ←→ 前端。

大模型:接受后端发过来的消息,返回响应消息 后端:接受前端发过来的消息,封装发给大模型;接收大模型返回的消息,回给后端;发送的消息和返回的消息都要保存到数据库 前端:发送消息,接受后端返回的响应消息,实时回显在聊天页面。

显然,websocket用在前后端之间进行交互,后端类似一个中间人,前端是一个用户,大模型是ai服务。

步骤与代码

实现websocket相关 1.1 注册到spring@Configuration

public class WebSocketConfig {

@Bean

public ServerEndpointExporter serverEndpointExporter() {

return new ServerEndpointExporter();

}

}

1.2 实现一个WebSocket的服务(别看这么长,其实参考了若依插件-集成websocket实现简单通信,但没涉及信号量之类所以没什么用,除了onMessage外,其他如onOpen打印一条消息就行了,更多如WebSocketUsers可以去链接那下载)@CrossOrigin

@Component

@ServerEndpoint("/websocket/message")

public class WebSocketServer {

private ChatRecordMapper chatRecordMapper = SpringUtils.getBean(ChatRecordMapper.class);

/**

* WebSocketServer 日志控制器

*/

private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class);

/**

* 默认最多允许同时在线人数100

*/

public static int socketMaxOnlineCount = 100;

private static Semaphore socketSemaphore = new Semaphore(socketMaxOnlineCount);

/**

* 连接建立成功调用的方法

*/

@OnOpen

public void onOpen(Session session) throws Exception {

boolean semaphoreFlag = false;

// 尝试获取信号量

semaphoreFlag = SemaphoreUtils.tryAcquire(socketSemaphore);

if (!semaphoreFlag) {

// 未获取到信号量

LOGGER.error("\n 当前在线人数超过限制数- {}", socketMaxOnlineCount);

WebSocketUsers.sendMessageToUserByText(session, "当前在线人数超过限制数:" + socketMaxOnlineCount);

session.close();

} else {

// 添加用户

WebSocketUsers.put(session.getId(), session);

LOGGER.info("\n 建立连接 - {}", session);

LOGGER.info("\n 当前人数 - {}", WebSocketUsers.getUsers().size());

WebSocketUsers.sendMessageToUserByText(session, "连接成功");

}

}

/**

* 连接关闭时处理

*/

@OnClose

public void onClose(Session session) {

LOGGER.info("\n 关闭连接 - {}", session);

// 移除用户

WebSocketUsers.remove(session.getId());

// 获取到信号量则需释放

SemaphoreUtils.release(socketSemaphore);

}

/**

* 抛出异常时处理

*/

@OnError

public void onError(Session session, Throwable exception) throws Exception {

if (session.isOpen()) {

// 关闭连接

session.close();

}

String sessionId = session.getId();

LOGGER.info("\n 连接异常 - {}", sessionId);

LOGGER.info("\n 异常信息 - {}", exception);

// 移出用户

WebSocketUsers.remove(sessionId);

// 获取到信号量则需释放

SemaphoreUtils.release(socketSemaphore);

}

/**

* 服务器接收到客户端消息时调用的方法

*/

@OnMessage

public void onMessage(String message, Session session) {

// 首先,接收到一条消息

LOGGER.info("\n 收到消息 - {}", message);

// 1. 调用大模型API,把上下文和这次问题传入,得到回复

BigModelService bigModelService = new BigModelService();

TurboResponse response = bigModelService.callModelAPI(session.getId(),message);

if (response == null) {

WebSocketUsers.sendMessageToUserByText(session, "抱歉,似乎出了点问题,请联系管理员");

return;

}

WebSocketUsers.sendMessageToUserByText(session, response.getResult());

}

}

实现请求接口相关 2.1 先写实体类,包括BaiduChatMessage(最基本的聊天消息)、ErnieBotTurboParam(ErnieBot-Turbo的请求参数,包括了List)TurboResponse(请求返回结果对应的实体类)@Data

@SuperBuilder

@NoArgsConstructor

@AllArgsConstructor

public class BaiduChatMessage implements Serializable {

private String role;

private String content;

}

@Data

@SuperBuilder

public class ErnieBotTurboParam implements Serializable {

/**

* 聊天上下文信息。说明:

* (1)messages成员不能为空,1个成员表示单轮对话,多个成员表示多轮对话

* (2)最后一个message为当前请求的信息,前面的message为历史对话信息

* (3)必须为奇数个成员,成员中message的role必须依次为user、assistant

* (4)最后一个message的content长度(即此轮对话的问题)不能超过2000个字符;如果messages中content总长度大于2000字符,系统会依次遗忘最早的历史会话,直到content的总长度不超过2000个字符

*/

protected List messages;

/**

* 是否以流式接口的形式返回数据,默认false

*/

protected Boolean stream;

/**

* 表示最终用户的唯一标识符,可以监视和检测滥用行为,防止接口恶意调用

*/

protected String user_id;

public boolean isStream() {

return Objects.equals(this.stream, true);

}

public ErnieBotTurboParam(){}

}

@Data

public class TurboResponse implements Serializable {

private String id;

private String object;

private Integer created;

private String sentence_id;

private Boolean is_end;

private Boolean is_truncated;

private String result;

private Boolean need_clear_history;

private Usage usage;

@Data

public static class Usage implements Serializable {

private Integer prompt_tokens;

private Integer completion_tokens;

private Integer total_tokens;

}

}

2.2 请求接口实现(注释很详细就不多说了)public class BigModelService {

private ChatRecordMapper chatRecordMapper = SpringUtils.getBean(ChatRecordMapper.class);

private static final Logger LOGGER = LoggerFactory.getLogger(BigModelService.class);

private static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();

public static final String API_KEY = "你的apikey";

public static final String SECRET_KEY = "你的secretkey";

static String getAccessToken() throws IOException {

MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");

RequestBody body = RequestBody.create(mediaType, "grant_type=client_credentials&client_id=" + API_KEY

+ "&client_secret=" + SECRET_KEY);

Request request = new Request.Builder()

.url("https://aip.baidubce.com/oauth/2.0/token")

.method("POST", body)

.addHeader("Content-Type", "application/x-www-form-urlencoded")

.build();

Response response = HTTP_CLIENT.newCall(request).execute();

// 解析返回的access_token

JSONObject jsonObject = JSONObject.parseObject(response.body().string());

return jsonObject.getString("access_token");

}

public TurboResponse callModelAPI(String sessionId, String message) {

// 1. 构建请求体

// 1.1 调用大模型API,要从数据库去查询上下文

ChatRecord cr = chatRecordMapper.selectChatRecordBySessionId(sessionId);

String records = cr == null ? "{}" : cr.getRecords();

// 1.2 把message加进请求体

// 1.2.1 解析上下文,获取聊天记录,把新的message封装加入到聊天记录中

ErnieBotTurboParam param = JSONObject.parseObject(records, ErnieBotTurboParam.class);

List messages = param.getMessages() == null ? new ArrayList<>() : param.getMessages();

messages.add(BaiduChatMessage.builder().role("user").content(message).build());

// 1.2.2 把messages重新设置到param中

param.setMessages(messages);

try {

// 2. 发出请求,调用大模型API

RequestBody body = RequestBody.create(MediaType.parse("application/json"), JSONObject.toJSONString(param));

Request request = new Request.Builder()

.url("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=" + getAccessToken())

.method("POST", body)

.addHeader("Content-Type", "application/json")

.build();

Response response = HTTP_CLIENT.newCall(request).execute();

if (response.isSuccessful()) {

// 3. 如果调用成功,

// 3.1 解析返回的聊天回复结果

TurboResponse turboResponse = JSONObject.parseObject(response.body().string(), TurboResponse.class);

LOGGER.info("调用大模型API成功: {}", turboResponse.toString());

// 3.2 将聊天回复结果存入数据库

// 3.2.1 先根据sessionId查询数据库

int count = chatRecordMapper.selectRecordCountBySessionId(sessionId);

// 将ai刚返回的回复追加到param中,再填入chatRecord

BaiduChatMessage aiMessage = BaiduChatMessage.builder()

.role("assistant").content(turboResponse.getResult()).build();

messages.add(aiMessage);

param.setMessages(messages);

if (count == 0) {

// 3.2.2 如果没有记录,则插入

ChatRecord chatRecord = new ChatRecord();

chatRecord.setRecordId(turboResponse.getId());

chatRecord.setSessionId(sessionId);

chatRecord.setRecords(JSONObject.toJSONString(param));

// 插入时应填入create_time字段

chatRecord.setCreateTime(LocalDateTime.now());

chatRecordMapper.insertChatRecord(chatRecord);

} else {

// 3.2.3 如果有记录,则更新

// 3.2.4 先查询出原来的记录的create_time,判断是否超过15min

ChatRecord chat = chatRecordMapper.selectChatRecordBySessionId(sessionId);

LocalDateTime createTime = chat.getCreateTime();

if (LocalDateTime.now().isAfter(createTime.plusMinutes(15))) {

// 2.2.2 如果超过15min,则清除records字段,将新的对话记录追加到records字段中

ChatRecord chatRecord = new ChatRecord();

chatRecord.setRecordId(turboResponse.getId());

chatRecord.setSessionId(sessionId);

chatRecord.setRecords(messages.toString());

// 更新时应填入create_time字段

chatRecord.setCreateTime(LocalDateTime.now());

chatRecordMapper.insertChatRecord(chatRecord);

} else {

ChatRecord chat0 = chatRecordMapper.selectChatRecordBySessionId(sessionId);

// 如果没有超过15min,则将新的对话记录追加到records字段中

ChatRecord chatRecord = new ChatRecord();

chatRecord.setRecordId(turboResponse.getId());

chatRecord.setSessionId(sessionId);

// 解析出原来的records

ChatRecord oldchat = chatRecordMapper.selectChatRecordBySessionId(sessionId);

ErnieBotTurboParam records1 = JSONObject.parseObject(oldchat.getRecords(), ErnieBotTurboParam.class);

// 将新的对话记录追加到records字段中

records1.setMessages(messages);

chatRecord.setRecords(JSONObject.toJSONString(records1));

// 没有15min就不更新create_time字段

// 更新chat_record

chatRecordMapper.updateChatRecord(chatRecord);

}

}

return turboResponse;

} else {

LOGGER.error("调用大模型API失败: {}", response.message());

}

} catch (IOException e) {

LOGGER.error("调用大模型API发生异常:", e);

}

return null;

}

}

持久化 3.1 数据库里建个表CREATE TABLE `chat_record` (

`record_id` varchar(20) NOT NULL COMMENT '记录id',

`session_id` varchar(10) DEFAULT NULL COMMENT '所属用户',

`records` json DEFAULT NULL COMMENT '聊天记录',

`create_time` datetime DEFAULT NULL COMMENT '创建时间(判断过期)',

PRIMARY KEY (`record_id`)

) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='聊天记录表';

3.2 对应实体类@Data

@NoArgsConstructor

@AllArgsConstructor

public class ChatRecord {

private String recordId;

private String sessionId;

private String records;

private LocalDateTime createTime;

}

3.3 再写个mapper就行了@Mapper

public interface ChatRecordMapper {

@Insert("INSERT INTO chat_record (record_id, session_id, records, create_time) " +

"VALUES (#{recordId}, #{sessionId}, #{records}, #{createTime})")

void insertChatRecord(ChatRecord chatRecord);

@Select("SELECT COUNT(*) FROM chat_record WHERE session_id = #{sessionId}")

int selectRecordCountBySessionId(String sessionId);

@Results({ // id是用来给@ResultMap注解引用的,到时候在xml中可以直接使用@ResultMap(value = "chatRecord")

@Result(property = "recordId", column = "record_id"),

@Result(property = "sessionId", column = "session_id"),

@Result(property = "records", column = "records"),

@Result(property = "createTime", column = "create_time")

})

@Select("SELECT * FROM chat_record WHERE session_id = #{sessionId}")

ChatRecord selectChatRecordBySessionId(String sessionId);

@Update("UPDATE chat_record SET records = #{records} WHERE session_id = #{sessionId}")

void updateChatRecord(ChatRecord chatRecord);

}

前端聊天页面与实时的回显 4.1 聊天页面写一个(这里前端是uniapp,样式用到了些colorUI)

4.2 js里写一个websocket(见上4.1的connect())

以上就大功告成了,这玩意还有很多缺漏和细节没做,像现在还是根据会话id去做,没有匹配用户id,15min清除聊天记录,但前端那没清……不过能跑能动就行,本来就是一个小任务,也懒得继续花时间调整。 记录一下,有问题可以交流

精彩链接

评论可见,请评论后查看内容,谢谢!!!
 您阅读本篇文章共花了: