Enhance agent chat functionality with ROM context support and structured message rendering
This commit is contained in:
@@ -1,18 +1,159 @@
|
||||
#include "cli/service/agent/conversational_agent_service.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/time/clock.h"
|
||||
#include "cli/service/ai/service_factory.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
namespace yaze {
|
||||
namespace cli {
|
||||
namespace agent {
|
||||
|
||||
namespace {
|
||||
|
||||
std::string TrimWhitespace(const std::string& input) {
|
||||
auto begin = std::find_if_not(input.begin(), input.end(),
|
||||
[](unsigned char c) { return std::isspace(c); });
|
||||
auto end = std::find_if_not(input.rbegin(), input.rend(),
|
||||
[](unsigned char c) { return std::isspace(c); })
|
||||
.base();
|
||||
if (begin >= end) {
|
||||
return "";
|
||||
}
|
||||
return std::string(begin, end);
|
||||
}
|
||||
|
||||
std::string JsonValueToString(const nlohmann::json& value) {
|
||||
if (value.is_string()) {
|
||||
return value.get<std::string>();
|
||||
}
|
||||
if (value.is_boolean()) {
|
||||
return value.get<bool>() ? "true" : "false";
|
||||
}
|
||||
if (value.is_number()) {
|
||||
return value.dump();
|
||||
}
|
||||
if (value.is_null()) {
|
||||
return "null";
|
||||
}
|
||||
return value.dump();
|
||||
}
|
||||
|
||||
std::set<std::string> CollectObjectKeys(const nlohmann::json& array) {
|
||||
std::set<std::string> keys;
|
||||
for (const auto& item : array) {
|
||||
if (!item.is_object()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto& [key, _] : item.items()) {
|
||||
keys.insert(key);
|
||||
}
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
std::optional<ChatMessage::TableData> BuildTableData(const nlohmann::json& data) {
|
||||
using TableData = ChatMessage::TableData;
|
||||
|
||||
if (data.is_object()) {
|
||||
TableData table;
|
||||
table.headers = {"Key", "Value"};
|
||||
table.rows.reserve(data.size());
|
||||
for (const auto& [key, value] : data.items()) {
|
||||
table.rows.push_back({key, JsonValueToString(value)});
|
||||
}
|
||||
return table;
|
||||
}
|
||||
|
||||
if (data.is_array()) {
|
||||
TableData table;
|
||||
if (data.empty()) {
|
||||
table.headers = {"Value"};
|
||||
return table;
|
||||
}
|
||||
|
||||
const bool all_objects = std::all_of(data.begin(), data.end(), [](const nlohmann::json& item) {
|
||||
return item.is_object();
|
||||
});
|
||||
|
||||
if (all_objects) {
|
||||
auto keys = CollectObjectKeys(data);
|
||||
if (keys.empty()) {
|
||||
table.headers = {"Value"};
|
||||
for (const auto& item : data) {
|
||||
table.rows.push_back({JsonValueToString(item)});
|
||||
}
|
||||
return table;
|
||||
}
|
||||
|
||||
table.headers.assign(keys.begin(), keys.end());
|
||||
table.rows.reserve(data.size());
|
||||
for (const auto& item : data) {
|
||||
std::vector<std::string> row;
|
||||
row.reserve(table.headers.size());
|
||||
for (const auto& key : table.headers) {
|
||||
if (item.contains(key)) {
|
||||
row.push_back(JsonValueToString(item.at(key)));
|
||||
} else {
|
||||
row.emplace_back("-");
|
||||
}
|
||||
}
|
||||
table.rows.push_back(std::move(row));
|
||||
}
|
||||
return table;
|
||||
}
|
||||
|
||||
table.headers = {"Value"};
|
||||
table.rows.reserve(data.size());
|
||||
for (const auto& item : data) {
|
||||
table.rows.push_back({JsonValueToString(item)});
|
||||
}
|
||||
return table;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
ChatMessage CreateMessage(ChatMessage::Sender sender, const std::string& content) {
|
||||
ChatMessage message;
|
||||
message.sender = sender;
|
||||
message.message = content;
|
||||
message.timestamp = absl::Now();
|
||||
|
||||
if (sender == ChatMessage::Sender::kAgent) {
|
||||
const std::string trimmed = TrimWhitespace(content);
|
||||
if (!trimmed.empty() && (trimmed.front() == '{' || trimmed.front() == '[')) {
|
||||
try {
|
||||
nlohmann::json parsed = nlohmann::json::parse(trimmed);
|
||||
message.table_data = BuildTableData(parsed);
|
||||
message.json_pretty = parsed.dump(2);
|
||||
} catch (const nlohmann::json::parse_error&) {
|
||||
// Ignore parse errors, fall back to raw text.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return message;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ConversationalAgentService::ConversationalAgentService() {
|
||||
ai_service_ = CreateAIService();
|
||||
}
|
||||
|
||||
void ConversationalAgentService::SetRomContext(Rom* rom) {
|
||||
rom_context_ = rom;
|
||||
tool_dispatcher_.SetRomContext(rom_context_);
|
||||
}
|
||||
|
||||
absl::StatusOr<ChatMessage> ConversationalAgentService::SendMessage(
|
||||
const std::string& message) {
|
||||
if (message.empty() && history_.empty()) {
|
||||
@@ -21,7 +162,7 @@ absl::StatusOr<ChatMessage> ConversationalAgentService::SendMessage(
|
||||
}
|
||||
|
||||
if (!message.empty()) {
|
||||
history_.push_back({ChatMessage::Sender::kUser, message, absl::Now()});
|
||||
history_.push_back(CreateMessage(ChatMessage::Sender::kUser, message));
|
||||
}
|
||||
|
||||
constexpr int kMaxToolIterations = 4;
|
||||
@@ -46,7 +187,7 @@ absl::StatusOr<ChatMessage> ConversationalAgentService::SendMessage(
|
||||
const std::string& tool_output = tool_result_or.value();
|
||||
if (!tool_output.empty()) {
|
||||
history_.push_back(
|
||||
{ChatMessage::Sender::kAgent, tool_output, absl::Now()});
|
||||
CreateMessage(ChatMessage::Sender::kAgent, tool_output));
|
||||
}
|
||||
executed_tool = true;
|
||||
}
|
||||
@@ -73,8 +214,8 @@ absl::StatusOr<ChatMessage> ConversationalAgentService::SendMessage(
|
||||
response_text.append(absl::StrJoin(agent_response.commands, "\n"));
|
||||
}
|
||||
|
||||
ChatMessage chat_response = {ChatMessage::Sender::kAgent, response_text,
|
||||
absl::Now()};
|
||||
ChatMessage chat_response =
|
||||
CreateMessage(ChatMessage::Sender::kAgent, response_text);
|
||||
history_.push_back(chat_response);
|
||||
return chat_response;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#ifndef YAZE_SRC_CLI_SERVICE_AGENT_CONVERSATIONAL_AGENT_SERVICE_H_
|
||||
#define YAZE_SRC_CLI_SERVICE_AGENT_CONVERSATIONAL_AGENT_SERVICE_H_
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -9,14 +10,23 @@
|
||||
#include "cli/service/agent/tool_dispatcher.h"
|
||||
|
||||
namespace yaze {
|
||||
|
||||
class Rom;
|
||||
|
||||
namespace cli {
|
||||
namespace agent {
|
||||
|
||||
struct ChatMessage {
|
||||
enum class Sender { kUser, kAgent };
|
||||
struct TableData {
|
||||
std::vector<std::string> headers;
|
||||
std::vector<std::vector<std::string>> rows;
|
||||
};
|
||||
Sender sender;
|
||||
std::string message;
|
||||
absl::Time timestamp;
|
||||
std::optional<std::string> json_pretty;
|
||||
std::optional<TableData> table_data;
|
||||
};
|
||||
|
||||
class ConversationalAgentService {
|
||||
@@ -29,10 +39,14 @@ class ConversationalAgentService {
|
||||
// Get the full chat history.
|
||||
const std::vector<ChatMessage>& GetHistory() const;
|
||||
|
||||
// Provide the service with a ROM context for tool execution.
|
||||
void SetRomContext(Rom* rom);
|
||||
|
||||
private:
|
||||
std::vector<ChatMessage> history_;
|
||||
std::unique_ptr<AIService> ai_service_;
|
||||
ToolDispatcher tool_dispatcher_;
|
||||
Rom* rom_context_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace agent
|
||||
|
||||
@@ -35,9 +35,9 @@ absl::StatusOr<std::string> ToolDispatcher::Dispatch(
|
||||
|
||||
absl::Status status;
|
||||
if (tool_call.tool_name == "resource-list") {
|
||||
status = HandleResourceListCommand(args);
|
||||
status = HandleResourceListCommand(args, rom_context_);
|
||||
} else if (tool_call.tool_name == "dungeon-list-sprites") {
|
||||
status = HandleDungeonListSpritesCommand(args);
|
||||
status = HandleDungeonListSpritesCommand(args, rom_context_);
|
||||
} else {
|
||||
status = absl::UnimplementedError(
|
||||
absl::StrFormat("Unknown tool: %s", tool_call.tool_name));
|
||||
|
||||
@@ -6,6 +6,9 @@
|
||||
#include "cli/service/ai/common.h"
|
||||
|
||||
namespace yaze {
|
||||
|
||||
class Rom;
|
||||
|
||||
namespace cli {
|
||||
namespace agent {
|
||||
|
||||
@@ -15,6 +18,11 @@ class ToolDispatcher {
|
||||
|
||||
// Execute a tool call and return the result as a string.
|
||||
absl::StatusOr<std::string> Dispatch(const ToolCall& tool_call);
|
||||
// Provide a ROM context for tool calls that require ROM access.
|
||||
void SetRomContext(Rom* rom) { rom_context_ = rom; }
|
||||
|
||||
private:
|
||||
Rom* rom_context_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace agent
|
||||
|
||||
Reference in New Issue
Block a user