mirror of
https://github.com/slhaf/Partner.git
synced 2026-05-12 16:53:04 +08:00
refactor(memory): manage state serialization via StateCenter in MemoryUnit, and support optional loading on register in StateCenter
This commit is contained in:
@@ -45,7 +45,13 @@ public class MemoryCore implements StateSerializable {
|
|||||||
|
|
||||||
@CapabilityMethod
|
@CapabilityMethod
|
||||||
public MemoryUnit getMemoryUnit(String unitId) {
|
public MemoryUnit getMemoryUnit(String unitId) {
|
||||||
return memoryUnits.computeIfAbsent(unitId, MemoryUnit::new);
|
MemoryUnit unit = memoryUnits.computeIfAbsent(unitId, id -> {
|
||||||
|
MemoryUnit newUnit = new MemoryUnit(id);
|
||||||
|
newUnit.register();
|
||||||
|
return newUnit;
|
||||||
|
});
|
||||||
|
unit.load();
|
||||||
|
return unit;
|
||||||
}
|
}
|
||||||
|
|
||||||
@CapabilityMethod
|
@CapabilityMethod
|
||||||
|
|||||||
@@ -1,18 +1,11 @@
|
|||||||
package work.slhaf.partner.core.memory.pojo;
|
package work.slhaf.partner.core.memory.pojo;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Getter;
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import work.slhaf.partner.framework.agent.common.entity.PersistableObject;
|
|
||||||
|
|
||||||
import java.io.Serial;
|
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@Getter
|
||||||
@Data
|
public class MemorySlice implements Comparable<MemorySlice> {
|
||||||
public class MemorySlice extends PersistableObject implements Comparable<MemorySlice> {
|
|
||||||
|
|
||||||
@Serial
|
|
||||||
private static final long serialVersionUID = 1L;
|
|
||||||
|
|
||||||
private final String id;
|
private final String id;
|
||||||
private final Integer startIndex;
|
private final Integer startIndex;
|
||||||
@@ -28,6 +21,18 @@ public class MemorySlice extends PersistableObject implements Comparable<MemoryS
|
|||||||
this.summary = summary;
|
this.summary = summary;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private MemorySlice(String id, Integer startIndex, Integer endIndex, String summary, Long timestamp) {
|
||||||
|
this.id = id;
|
||||||
|
this.startIndex = startIndex;
|
||||||
|
this.endIndex = endIndex;
|
||||||
|
this.summary = summary;
|
||||||
|
this.timestamp = timestamp;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static MemorySlice restore(String id, Integer startIndex, Integer endIndex, String summary, Long timestamp) {
|
||||||
|
return new MemorySlice(id, startIndex, endIndex, summary, timestamp);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int compareTo(MemorySlice memorySlice) {
|
public int compareTo(MemorySlice memorySlice) {
|
||||||
if (memorySlice.getTimestamp() > this.getTimestamp()) {
|
if (memorySlice.getTimestamp() > this.getTimestamp()) {
|
||||||
|
|||||||
@@ -1,24 +1,125 @@
|
|||||||
package work.slhaf.partner.core.memory.pojo;
|
package work.slhaf.partner.core.memory.pojo;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson2.JSONArray;
|
||||||
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
import work.slhaf.partner.framework.agent.model.pojo.Message;
|
||||||
|
import work.slhaf.partner.framework.agent.state.State;
|
||||||
|
import work.slhaf.partner.framework.agent.state.StateSerializable;
|
||||||
|
import work.slhaf.partner.framework.agent.state.StateValue;
|
||||||
|
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
public class MemoryUnit {
|
public class MemoryUnit implements StateSerializable {
|
||||||
|
|
||||||
private final String id;
|
private final String id;
|
||||||
private final List<Message> conversationMessages = new ArrayList<>();
|
private final List<Message> conversationMessages = new ArrayList<>();
|
||||||
private Long timestamp;
|
private Long timestamp = 0L;
|
||||||
private final List<MemorySlice> slices = new ArrayList<>();
|
private final List<MemorySlice> slices = new ArrayList<>();
|
||||||
|
|
||||||
public MemoryUnit(String id) {
|
public MemoryUnit(String id) {
|
||||||
this.id = id;
|
this.id = id;
|
||||||
|
this.register();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void updateTimestamp() {
|
public void updateTimestamp() {
|
||||||
timestamp = System.currentTimeMillis();
|
timestamp = System.currentTimeMillis();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public @NotNull Path statePath() {
|
||||||
|
return Path.of("core", "memory", "memory-unit" + id + ".json");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void load(@NotNull JSONObject state) {
|
||||||
|
Long loadedTimestamp = state.getLong("update_timestamp");
|
||||||
|
this.timestamp = loadedTimestamp != null ? loadedTimestamp : 0L;
|
||||||
|
|
||||||
|
this.conversationMessages.clear();
|
||||||
|
this.slices.clear();
|
||||||
|
|
||||||
|
JSONArray messageArray = state.getJSONArray("conversation_messages");
|
||||||
|
if (messageArray != null) {
|
||||||
|
for (int i = 0; i < messageArray.size(); i++) {
|
||||||
|
JSONObject messageObject = messageArray.getJSONObject(i);
|
||||||
|
if (messageObject == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
String role = messageObject.getString("role");
|
||||||
|
String content = messageObject.getString("content");
|
||||||
|
if (role == null || content == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Message message = new Message(Message.Character.fromValue(role), content);
|
||||||
|
|
||||||
|
this.conversationMessages.add(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var sliceArray = state.getJSONArray("memory_slices");
|
||||||
|
if (sliceArray != null) {
|
||||||
|
for (int i = 0; i < sliceArray.size(); i++) {
|
||||||
|
JSONObject sliceObject = sliceArray.getJSONObject(i);
|
||||||
|
if (sliceObject == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
String sliceId = sliceObject.getString("id");
|
||||||
|
Integer startIndex = sliceObject.getInteger("start_index");
|
||||||
|
Integer endIndex = sliceObject.getInteger("end_index");
|
||||||
|
String summary = sliceObject.getString("summary");
|
||||||
|
Long createdTimestamp = sliceObject.getLong("created_timestamp");
|
||||||
|
|
||||||
|
if (sliceId == null || startIndex == null || endIndex == null || summary == null || createdTimestamp == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
MemorySlice slice = MemorySlice.restore(sliceId, startIndex, endIndex, summary, createdTimestamp);
|
||||||
|
|
||||||
|
this.slices.add(slice);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public @NotNull State convert() {
|
||||||
|
State state = new State();
|
||||||
|
state.append("id", StateValue.str(id));
|
||||||
|
state.append("update_timestamp", StateValue.num(timestamp));
|
||||||
|
|
||||||
|
List<StateValue.Obj> convertedMessageList = conversationMessages.stream().map(message -> {
|
||||||
|
Map<String, StateValue> convertedMap = Map.of(
|
||||||
|
"role", StateValue.str(message.roleValue()),
|
||||||
|
"content", StateValue.str(message.getContent())
|
||||||
|
);
|
||||||
|
return StateValue.obj(convertedMap);
|
||||||
|
}).toList();
|
||||||
|
state.append("conversation_messages", StateValue.arr(convertedMessageList));
|
||||||
|
|
||||||
|
List<StateValue.Obj> convertedSliceList = slices.stream().map(slice -> {
|
||||||
|
Map<String, StateValue> convertedMap = Map.of(
|
||||||
|
"id", StateValue.str(slice.getId()),
|
||||||
|
"start_index", StateValue.num(slice.getStartIndex()),
|
||||||
|
"end_index", StateValue.num(slice.getEndIndex()),
|
||||||
|
"summary", StateValue.str(slice.getSummary()),
|
||||||
|
"created_timestamp", StateValue.num(slice.getTimestamp())
|
||||||
|
);
|
||||||
|
return StateValue.obj(convertedMap);
|
||||||
|
}).toList();
|
||||||
|
state.append("memory_slices", StateValue.arr(convertedSliceList));
|
||||||
|
return state;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean autoLoadOnRegister() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package work.slhaf.partner.framework.agent.state
|
|||||||
|
|
||||||
import com.alibaba.fastjson2.JSONArray
|
import com.alibaba.fastjson2.JSONArray
|
||||||
import com.alibaba.fastjson2.JSONObject
|
import com.alibaba.fastjson2.JSONObject
|
||||||
|
import org.slf4j.LoggerFactory
|
||||||
import work.slhaf.partner.framework.agent.config.ConfigCenter
|
import work.slhaf.partner.framework.agent.config.ConfigCenter
|
||||||
import java.nio.file.Files
|
import java.nio.file.Files
|
||||||
import java.nio.file.Path
|
import java.nio.file.Path
|
||||||
@@ -13,7 +14,9 @@ import kotlin.io.path.writeText
|
|||||||
|
|
||||||
object StateCenter {
|
object StateCenter {
|
||||||
|
|
||||||
private val stateRegistry = ConcurrentHashMap<Path, StateSerializable>()
|
private val log = LoggerFactory.getLogger(StateCenter::class.java)
|
||||||
|
|
||||||
|
private val stateRegistry = ConcurrentHashMap<Path, StateRecord>()
|
||||||
|
|
||||||
fun register(stateSerializable: StateSerializable): JSONObject? {
|
fun register(stateSerializable: StateSerializable): JSONObject? {
|
||||||
val relativePath = stateSerializable.statePath().normalize()
|
val relativePath = stateSerializable.statePath().normalize()
|
||||||
@@ -23,8 +26,9 @@ object StateCenter {
|
|||||||
val finalStatePath = stateDir.resolve(relativePath).normalize()
|
val finalStatePath = stateDir.resolve(relativePath).normalize()
|
||||||
check(finalStatePath.startsWith(stateDir)) { "StatePath escapes stateDir" }
|
check(finalStatePath.startsWith(stateDir)) { "StatePath escapes stateDir" }
|
||||||
|
|
||||||
val previous = stateRegistry.putIfAbsent(finalStatePath, stateSerializable)
|
val stateRecord = StateRecord(stateSerializable)
|
||||||
check(previous == null || previous === stateSerializable) {
|
val previous = stateRegistry.putIfAbsent(finalStatePath, stateRecord)
|
||||||
|
check(previous == null || previous.serializable === stateSerializable) {
|
||||||
"StatePath already registered: $finalStatePath"
|
"StatePath already registered: $finalStatePath"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,15 +39,43 @@ object StateCenter {
|
|||||||
check(finalStatePath.isRegularFile()) { "StatePath must point to a regular file: $finalStatePath" }
|
check(finalStatePath.isRegularFile()) { "StatePath must point to a regular file: $finalStatePath" }
|
||||||
check(finalStatePath.toFile().canRead()) { "StateFile must be readable: $finalStatePath" }
|
check(finalStatePath.toFile().canRead()) { "StateFile must be readable: $finalStatePath" }
|
||||||
|
|
||||||
|
if (!stateSerializable.autoLoadOnRegister()) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
stateRecord.loaded = true
|
||||||
|
|
||||||
return JSONObject.parseObject(finalStatePath.readText())
|
return JSONObject.parseObject(finalStatePath.readText())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun load(path: Path) {
|
||||||
|
val finalStatePath = ConfigCenter.paths.stateDir.normalize().resolve(path).normalize()
|
||||||
|
if (!stateRegistry.containsKey(path)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
val record = stateRegistry[finalStatePath] ?: return
|
||||||
|
record.loaded = true
|
||||||
|
if (!finalStatePath.exists()) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
val json = JSONObject.parseObject(finalStatePath.readText())
|
||||||
|
record.serializable.load(json)
|
||||||
|
} catch (_: Exception) {
|
||||||
|
log.warn("StateCenter loading failed: $path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fun save() {
|
fun save() {
|
||||||
stateRegistry.forEach { (path, state) ->
|
stateRegistry.forEach { (path, record) ->
|
||||||
|
if (!record.loaded) {
|
||||||
|
return@forEach
|
||||||
|
}
|
||||||
path.parent?.let(Files::createDirectories)
|
path.parent?.let(Files::createDirectories)
|
||||||
path.writeText(state.convert().toString())
|
path.writeText(record.serializable.convert().toString())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
interface StateSerializable {
|
interface StateSerializable {
|
||||||
@@ -57,9 +89,27 @@ interface StateSerializable {
|
|||||||
|
|
||||||
fun statePath(): Path
|
fun statePath(): Path
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 手动加载状态数据
|
||||||
|
*/
|
||||||
|
fun load() {
|
||||||
|
StateCenter.load(statePath())
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 状态加载逻辑
|
||||||
|
*/
|
||||||
fun load(state: JSONObject)
|
fun load(state: JSONObject)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 数据转换为状态逻辑
|
||||||
|
*/
|
||||||
fun convert(): State
|
fun convert(): State
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否在注册时即触发一次加载
|
||||||
|
*/
|
||||||
|
fun autoLoadOnRegister(): Boolean = true
|
||||||
}
|
}
|
||||||
|
|
||||||
class State {
|
class State {
|
||||||
@@ -68,8 +118,6 @@ class State {
|
|||||||
|
|
||||||
fun append(key: String, value: StateValue) = json.put(key, value.toJsonValue())
|
fun append(key: String, value: StateValue) = json.put(key, value.toJsonValue())
|
||||||
|
|
||||||
fun toJson(): JSONObject = json
|
|
||||||
|
|
||||||
override fun toString(): String = json.toString()
|
override fun toString(): String = json.toString()
|
||||||
|
|
||||||
private fun StateValue.toJsonValue(): Any = when (this) {
|
private fun StateValue.toJsonValue(): Any = when (this) {
|
||||||
@@ -105,3 +153,8 @@ sealed interface StateValue {
|
|||||||
fun obj(value: Map<String, StateValue>) = Obj(value)
|
fun obj(value: Map<String, StateValue>) = Obj(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data class StateRecord(
|
||||||
|
val serializable: StateSerializable,
|
||||||
|
var loaded: Boolean = false
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user