Spring AI 系列之五 - 聊天记忆之自定义

之前做个几个大模型的应用,都是使用Python语言,后来有一个项目使用了Java,并使用了Spring AI框架。随着Spring AI不断地完善,最近它发布了1.0正式版,意味着它已经能很好的作为企业级生产环境的使用。对于Java开发者来说真是一个福音,其功能已经能满足基于大模型开发企业级应用。借着这次机会,给大家分享一下Spring AI框架。

注意由于框架不同版本改造会有些使用的不同,因此本次系列中使用基本框架是 Spring AI-1.0.0,JDK版本使用的是19
代码参考: https://github.com/forever1986/springai-study

上一章讲解Spring AI的聊天记忆功能,包括入门、类型以及存储方式,这一章通过实现一个自定义的Redis存储来存储聊天记忆。

1 原理分析

在上一章中,讲到Spring AI实现聊天记忆存储是通过以下代码实现:

.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build()) // 通过不同角色Message方式传递聊天记忆

其中ChatMemory的实现类MessageWindowChatMemory中就需要配置ChatMemoryRepository。因此如果要实现不同类型的存储,那么可以通过两种方式扩展:一个是实现ChatMemory接口;一个是实现ChatMemoryRepository接口。

下面是ChatMemory接口需要实现的方法

public interface ChatMemory {
	/**
	 * 将单次的聊天记录保存到对话中
	 */
	void add(String conversationId, List<Message> messages);

	/**
	 * 通过对话id,获取对话的聊天记录
	 */
	List<Message> get(String conversationId);

	/**
	 * 清除某次聊天记录
	 */
	void clear(String conversationId);
}

以下是 ChatMemory 的实现类 MessageWindowChatMemory 的部分实现代码,如下:

/**
 * 获取历史的聊天记录,通过process处理(主要过滤重复以及System角色的消息),最后全部调用chatMemoryRepository保存
 */
@Override
public void add(String conversationId, List<Message> messages) {
	Assert.hasText(conversationId, "conversationId cannot be null or empty");
	Assert.notNull(messages, "messages cannot be null");
	Assert.noNullElements(messages, "messages cannot contain null elements");

	List<Message> memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId);
	List<Message> processedMessages = process(memoryMessages, messages);
	this.chatMemoryRepository.saveAll(conversationId, processedMessages);
}

/**
 * 使用chatMemoryRepository获取某个对话id的聊天记录
 */
@Override
public List<Message> get(String conversationId) {
	Assert.hasText(conversationId, "conversationId cannot be null or empty");
	return this.chatMemoryRepository.findByConversationId(conversationId);
}

/**
 * 使用chatMemoryRepository清除某个对话id的聊天记录
 */
@Override
public void clear(String conversationId) {
	Assert.hasText(conversationId, "conversationId cannot be null or empty");
	this.chatMemoryRepository.deleteByConversationId(conversationId);
}

从上面MessageWindowChatMemory可知,其最终都是使用ChatMemoryRepository来实现,因此也可以实现ChatMemoryRepository接口的方式来实现Redis存储,下面是ChatMemoryRepository接口的源码:

public interface ChatMemoryRepository {

	/**
 	 * 返回所有的对话id
 	 */
	List<String> findConversationIds();

	/**
 	 * 返回对话id的所有聊天记录
 	 */
	List<Message> findByConversationId(String conversationId);

	/**
	 * 替换原先保存对话的所有聊天记录
	 */
	void saveAll(String conversationId, List<Message> messages);

	/**
 	 * 删除对话id的所有聊天记录
 	 */
	void deleteByConversationId(String conversationId);

}

说明:因此可以通过实现ChatMemory接口或者ChatMemoryRepository接口来自定义存储,下面基于实现ChatMemoryRepository接口方式来演示

2 自定义聊天记忆存储-Redis

代码参考lesson06子模块

2.1 前期准备

1)准备一个Redis服务器

2)在springai-study父项目下,新建lesson06子模块,其pom引入

<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-model-zhipuai</artifactId>
    </dependency>
    <!-- 引入redis依赖 -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-redis</artifactId>
    </dependency>
    <!-- 使用lettuce连接池,需要引入commons-pool2 -->
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-pool2</artifactId>
    </dependency>
</dependencies>

3)在resources目录下,创建application.properties配置文件

# 聊天模型
spring.ai.zhipuai.api-key=你的智谱模型的API KEY
spring.ai.zhipuai.chat.options.model=GLM-4-Flash-250414
spring.ai.zhipuai.chat.options.temperature=0.7

# redis配置
spring.data.redis.database=1
spring.data.redis.host=127.0.0.1
spring.data.redis.port=6379
spring.data.redis.lettuce.pool.max-active=10
spring.data.redis.lettuce.pool.max-idle=10
spring.data.redis.lettuce.pool.min-idle=0

4)创建启动类Lesson06Application:

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class Lesson06Application {

    public static void main(String[] args) {
        SpringApplication.run(Lesson06Application.class, args);
    }

}

2.2 redis配置

1)新建redis序列化类RedisMessageSerializer :

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.data.redis.serializer.RedisSerializer;

import java.io.IOException;

public class RedisMessageSerializer implements RedisSerializer<Message> {

    private final ObjectMapper objectMapper;
    private final JsonDeserializer<Message> messageDeserializer;

    public RedisMessageSerializer(ObjectMapper objectMapper) {
        this.objectMapper = objectMapper;
        this.messageDeserializer = new JsonDeserializer<>() {
            @Override
            public Message deserialize(JsonParser jp, DeserializationContext ctx)
                    throws IOException {
                ObjectNode root = jp.readValueAsTree();
                String type = root.get("messageType").asText();

                return switch (type) {
                    case "USER" -> new UserMessage(root.get("text").asText());
                    case "ASSISTANT" -> new AssistantMessage(root.get("text").asText());
                    case "SYSTEM" -> new SystemMessage(root.get("text").asText());
                    default -> throw new UnsupportedOperationException("消息类型错误");
                };
            }
        };
    }

    @Override
    public byte[] serialize(Message message) {
        try {
            return objectMapper.writeValueAsBytes(message);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("序列化失败", e);
        }
    }

    @Override
    public Message deserialize(byte[] bytes) {
        if (bytes == null || bytes.length == 0) {
            return null;
        }
        try {
            return messageDeserializer.deserialize(objectMapper.getFactory().createParser(bytes), objectMapper.getDeserializationContext());
        } catch (Exception e) {
            throw new RuntimeException("反序列化识别", e);
        }
    }
}

2)创建RedisConfiguration 配置类设置redisTemplate:

import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder;

@Configuration
public class RedisConfiguration {

    @Bean
    @ConditionalOnMissingBean({RedisTemplate.class})
    public RedisTemplate redisTemplate(RedisConnectionFactory factory) {
        RedisTemplate<String, Object> template = new RedisTemplate();
        template.setConnectionFactory(factory);
        ObjectMapper om = new ObjectMapper();
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        RedisMessageSerializer redisMessageSerializer = new RedisMessageSerializer(om);
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        template.setKeySerializer(stringRedisSerializer);
        template.setHashKeySerializer(stringRedisSerializer);
        template.setValueSerializer(redisMessageSerializer);
        template.setHashValueSerializer(redisMessageSerializer);
        template.afterPropertiesSet();
        return template;
    }

}

2.3 自定义redis的存储

1)自定义Repository

import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.Message;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Repository;

import java.util.List;
import java.util.Set;

@Repository
public class RedisChatMemoryRepository implements ChatMemoryRepository {

    private static final String REDIS_KEY_PREFIX = "chatmemory:";

    private final RedisTemplate<String, Message> redisTemplate;

    public RedisChatMemoryRepository(RedisTemplate<String, Message> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    @Override
    public List<String> findConversationIds() {
        Set<String> keys = this.redisTemplate.keys(REDIS_KEY_PREFIX+"*");
        return keys.stream().toList();
    }

    @Override
    public List<Message> findByConversationId(String conversationId) {
        return this.redisTemplate.opsForList().range(REDIS_KEY_PREFIX+conversationId, 0, -1);
    }

    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        this.redisTemplate.delete(REDIS_KEY_PREFIX+conversationId);// 由于每次的messages都会获取到之前的数据,因此要先删除,在插入
        this.redisTemplate.opsForList().rightPushAll(REDIS_KEY_PREFIX+conversationId, messages);
    }

    @Override
    public void deleteByConversationId(String conversationId) {
        redisTemplate.delete(REDIS_KEY_PREFIX + conversationId);
    }
}

2)创建RedisMemoryController 进行演示

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class RedisMemoryController {

    private ChatClient chatClient;

    public RedisMemoryController(ChatClient.Builder chatClientBuilder, ChatMemory chatMemory) {
        this.chatClient = chatClientBuilder
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build()) // 通过不同角色Message方式传递聊天记忆
                .build();
    }

    /**
     * @param message 问题
     * @param conversationId 聊天记忆的id
     */
    @GetMapping("/ai/redismemory")
    public String memory(@RequestParam(value = "message", required = true) String message
            , @RequestParam(value = "conversationId", required = true) Integer conversationId) {
        return this.chatClient.prompt()
                .user(message)
                .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
                .call()
                .content();
    }

}

2.4 演示结果

1)请求以下url

http://localhost:8080/ai/redismemory?message=给我推荐10部电影&conversationId=1

在这里插入图片描述

2)查看redis数据库,可以看到记录已经存储在redis数据库中

在这里插入图片描述

2)继续访问地址,可以看到是失效的

http://localhost:8080/ai/redismemory?message=给我推荐最好的一部&conversationId=1

结语:本章通过分析Spring AI的聊天记忆存储原理,并通过自定义Redis存储作为演示。下一章将讲解如何让大模型使用工具。

Spring AI系列上一章:《Spring AI 系列之四 - 聊天记忆之入门

Spring AI系列下一章:《Spring AI 系列之六 - 工具调用

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

linmoo2006

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值