Compare commits

..

250 Commits

Author SHA1 Message Date
zhayujie
db85b9808e feat(cli): add cow update 2026-03-28 18:58:42 +08:00
zhayujie
df5bae37bc feat: add MiniMax-M2.7 and glm-5-turbo in web console 2026-03-28 18:48:11 +08:00
zhayujie
acc23b6051 feat: optimize agent prompt and fix skill source load 2026-03-28 18:37:07 +08:00
zhayujie
61f2741afc feat: organize skill source field 2026-03-28 17:41:40 +08:00
zhayujie
4dd7ea886a feat(cli): cli options in web console 2026-03-28 16:26:41 +08:00
zhayujie
1e8959fbcf fix: optimize repo clone in run.sh 2026-03-28 15:08:57 +08:00
zhayujie
48729678cf Merge branch 'master' into feat-cow-cli 2026-03-28 14:47:20 +08:00
zhayujie
0684becaa7 fix(cli): register skill when installing 2026-03-28 14:42:18 +08:00
zhayujie
db16bdf8cb fix(cli): add security hardening for skill install and process management 2026-03-27 17:59:15 +08:00
zhayujie
f890318ed9 fix: strip leading/trailing whitespace from agent response 2026-03-26 18:13:39 +08:00
zhayujie
158510cbbe feat(cli): imporve cow cli and skill hub integration 2026-03-26 16:49:42 +08:00
zhayujie
ce90cf7aa8 fix: weixin cdn upload retry 2026-03-26 10:20:29 +08:00
zhayujie
a3a3d006eb Merge pull request #2723 from Xiaozhou345/Xiaozhou345-fix-readme-spacing
优化 README 中的中英文排版空格
2026-03-26 10:14:27 +08:00
zhayujie
8fd029a4a1 feat(cli): support cow cli 2026-03-26 10:08:51 +08:00
Xiaozhou345
2e1b52c1e5 优化 README 中的中英文排版空格
按照中文技术文档规范,在文件名和中文之间增加了空格,提升可读性。
2026-03-25 21:26:01 +08:00
zhayujie
3eb8348708 fix: docker volume permission issue and clean up unused dependencies 2026-03-25 01:25:34 +08:00
zhayujie
393f0c007c fix: context loss after trim 2026-03-24 20:49:28 +08:00
zhayujie
c062ca8c66 Merge pull request #2720 from 6vision/fix/deepseek-docs
Docs: update
2026-03-24 00:25:17 +08:00
6vision
76dcb25103 docs(deepseek): update model descriptions to V3.2 with thinking/non-thinking mode
Made-with: Cursor
2026-03-24 00:05:39 +08:00
6vision
c5b4f236db docs(deepseek): remove migration notes from zh and en docs
Made-with: Cursor
2026-03-24 00:05:39 +08:00
zhayujie
0974c940a8 Merge pull request #2719 from 6vision/feat/deepseek-bot
feat: add independent DeepSeek bot module with dedicated config
2026-03-23 22:42:58 +08:00
6vision
cffa20d37e docs(deepseek): remove migration notes to reduce user cognitive load
Made-with: Cursor
2026-03-23 22:39:15 +08:00
6vision
ef009edd29 docs(deepseek): update config guides for independent DeepSeek module
Update DeepSeek docs (zh/en/ja) and README to reflect the new dedicated deepseek_api_key / deepseek_api_base config fields, with backward compatibility notes.

Made-with: Cursor
2026-03-23 21:43:51 +08:00
zhayujie
3ca52b118d fix(weixin): qrcode url log 2026-03-23 21:33:53 +08:00
zhayujie
13f5fde4fb fix: rebuild system prompt from scratch on every turn 2026-03-23 21:27:44 +08:00
6vision
f512b55ec2 feat(deepseek): add independent DeepSeek bot module with dedicated config
Separate DeepSeek from ChatGPTBot into its own module (models/deepseek/) with dedicated deepseek_api_key and deepseek_api_base config fields, avoiding config conflicts when switching between providers. Backward compatible with old users who configured DeepSeek via open_ai_api_key/open_ai_api_base through automatic fallback.

Made-with: Cursor
2026-03-23 21:23:35 +08:00
zhayujie
22b8ca0095 feat: optimize vision image compression 2026-03-23 21:18:04 +08:00
zhayujie
baf66a103d fix(weixin): preserve original filename for received files 2026-03-23 01:18:02 +08:00
zhayujie
45faa9c1ff fix(wexin): resolve image/file send and receive failures 2026-03-23 00:13:41 +08:00
zhayujie
304381a88d fix: hide breadcrumb on mobile for better space utilization 2026-03-22 23:36:34 +08:00
zhayujie
fc9f54dbc8 feat(weixin): optimize login qrcode generate 2026-03-22 23:04:50 +08:00
zhayujie
7199dc187f fix: default gemini model 2026-03-22 22:52:37 +08:00
zhayujie
e9ae066d53 Merge pull request #2716 from cowagent/fix-gemini-model-attribute
fix: add missing model property to GoogleGeminiBot
2026-03-22 22:49:00 +08:00
cowagent
d71ae406ff fix: add missing model property to GoogleGeminiBot
api_key and api_base were refactored to @property but model was not
migrated, causing AttributeError: 'GoogleGeminiBot' object has no
attribute 'model' when using any Gemini model.
2026-03-22 22:43:26 +08:00
zhayujie
f3216904b3 feat(weixin): optimize weixin login qrcode 2026-03-22 21:34:47 +08:00
zhayujie
5958b69ec9 feat: release 2.0.4 2026-03-22 20:49:41 +08:00
zhayujie
7d4e2cb39a docs: update comments 2026-03-22 19:07:19 +08:00
zhayujie
a483ec0cea feat: optimize weixin channel qr code generate 2026-03-22 18:20:10 +08:00
zhayujie
c1421e0874 feat: support weixin channel in scripts 2026-03-22 16:29:12 +08:00
zhayujie
ce89869c3c feat: support weixin channel 2026-03-22 15:52:13 +08:00
zhayujie
b8b57e34ff fix: auto-repair messages 2026-03-21 14:20:22 +08:00
zhayujie
bc7f627253 fix(wecom_bot): compat with old websocket-client 2026-03-21 14:03:17 +08:00
zhayujie
652156e398 feat: make run.sh executable 2026-03-20 17:56:10 +08:00
zhayujie
9febb071c6 fix: run.sh get pid bug 2026-03-20 17:51:04 +08:00
zhayujie
7d0e1568ac fix: feishu msg and log encoding 2026-03-19 17:07:39 +08:00
zhayujie
b4e711f411 feat: add request header 2026-03-19 17:06:05 +08:00
zhayujie
1b5be1b981 fix: remove feishu_bot_name in run.sh 2026-03-19 14:55:12 +08:00
zhayujie
49d8707c58 refactor: simplify run.sh by extracting shared logic and eliminating duplication 2026-03-19 11:07:16 +08:00
zhayujie
9192f6f7f7 feat: add MiniMax-M2.7 and glm-5-turbo 2026-03-19 10:46:13 +08:00
zhayujie
05022e3745 fix: add log 2026-03-18 23:09:27 +08:00
zhayujie
5356e9ddeb docs: adjust docs order 2026-03-18 21:55:09 +08:00
zhayujie
52acf76e2c docs: update jp docs 2026-03-18 21:01:02 +08:00
zhayujie
40cdbd3b45 Merge pull request #2710 from eltociear/add-ja-doc
docs: add Japanese documents
2026-03-18 19:28:04 +08:00
Ikko Ashimine
5487c0befe docs: add Japanese documents 2026-03-18 19:13:39 +09:00
zhayujie
8bb16c48c0 docs: update install cmd 2026-03-18 16:11:35 +08:00
zhayujie
c6384363f9 feat: workspace volume in docker deploy 2026-03-18 16:03:03 +08:00
zhayujie
8993e8ad3e feat: release 2.0.3 2026-03-18 15:40:49 +08:00
zhayujie
289989d9f7 feat: release 2.0.3 2026-03-18 15:10:21 +08:00
zhayujie
dc2ae0e6f1 feat: support gpt-5.4-mini and gpt-5.4-nano 2026-03-18 14:55:29 +08:00
zhayujie
9c966c152d feat: enhance AGENT.md update prompts to encourage proactive evolution 2026-03-18 12:10:45 +08:00
zhayujie
4efae41048 feat: support coding plan 2026-03-18 11:59:22 +08:00
zhayujie
b8437032e9 fix: optimize image recognition prompts 2026-03-18 10:10:23 +08:00
zhayujie
2d339ca81b Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-03-17 23:03:05 +08:00
zhayujie
d53abc9696 docs: update README.md 2026-03-17 23:02:41 +08:00
zhayujie
446c886d38 Merge pull request #2706 from zhayujie/feat-web-files
feat: support files upload in web console and office parsing
2026-03-17 21:22:38 +08:00
zhayujie
30c6d9b5ae feat: support file and image upload in web console, add office docs parsing in read tool 2026-03-17 21:21:03 +08:00
zhayujie
5e42996b36 fix: guide LLM to use matching skill when tool not found 2026-03-17 18:34:09 +08:00
zhayujie
ceca7b85bf Merge pull request #2705 from zhayujie/feat-qq-channel
feat: add qq channel
2026-03-17 17:26:39 +08:00
zhayujie
a4d54f58c8 feat: complete the QQ channel and supplement the docs 2026-03-17 17:25:36 +08:00
zhayujie
005a0e1bad feat: add qq channel 2026-03-17 15:43:04 +08:00
zhayujie
46d97fd57d feat: channel config set to env 2026-03-17 11:36:20 +08:00
zhayujie
72a26b6353 fix: scheduler auto clean 2026-03-17 11:29:21 +08:00
zhayujie
89a4033fbf fix: web console bot_type 2026-03-17 10:47:41 +08:00
zhayujie
39a5dc64bd Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-03-16 19:07:54 +08:00
zhayujie
d4bdd9b1b7 docs: update README.md for wecom_bot channel 2026-03-16 19:07:08 +08:00
zhayujie
2f5ba87280 Merge pull request #2698 from zhayujie/feat-wecom-bot
feat: wecom_bot channel
2026-03-16 19:04:52 +08:00
zhayujie
8b45d6c750 docs: wecom_bot integration docs 2026-03-16 19:03:18 +08:00
zhayujie
4ecd4df2d4 feat: web console support wecom_bot config 2026-03-16 17:56:59 +08:00
zhayujie
a42f31fe52 feat: support wecom_bot stream card 2026-03-16 17:46:05 +08:00
zhayujie
d4480b695e feat(channel): add wecom_bot channel 2026-03-16 14:39:15 +08:00
zhayujie
c4b5f7fbae refactor: remove unavailable channels 2026-03-16 11:05:45 +08:00
zhayujie
ba915f2cc0 feat: add gemini-3.1-flash-lite-preview and gpt-5.4 2026-03-15 22:06:12 +08:00
zhayujie
4b91140f31 fix: optimize msg receive 2026-03-12 20:49:36 +08:00
zhayujie
9879878dd0 fix: concurrency issue in session 2026-03-12 17:08:09 +08:00
zhayujie
d78105d57c fix: tool call match 2026-03-12 17:05:27 +08:00
zhayujie
153c9e3565 fix(memory): remove useless prompt 2026-03-12 15:29:58 +08:00
zhayujie
c11623596d fix(memory): prevent context memory loss by improving trim strategy 2026-03-12 15:25:46 +08:00
zhayujie
e791a77f77 fix: strengthen bootstrap flow 2026-03-12 12:13:05 +08:00
zhayujie
b641bffb2c fix(feishu): remove bot_name dependency for group chat 2026-03-12 11:30:42 +08:00
zhayujie
ee0c47ac1e feat: file send prompt 2026-03-12 00:11:34 +08:00
zhayujie
eba90e9343 fix: workspace bootstrap 2026-03-11 23:35:42 +08:00
zhayujie
d8374d0fa5 fix: web_fetch encoding 2026-03-11 19:42:37 +08:00
zhayujie
fa61744c6d feat(web_fetch): support downloading and parsing remote document files (PDF, Word, Excel, PPT) 2026-03-11 17:47:15 +08:00
zhayujie
4fec55cc01 feat: web_featch tool support remote file url 2026-03-11 17:16:39 +08:00
zhayujie
1767413712 fix: increase minimax max_tokens 2026-03-11 15:31:35 +08:00
zhayujie
734c8fa84f fix: optimize skill prompt 2026-03-11 12:40:37 +08:00
zhayujie
9a8d422554 feat: package skill install 2026-03-11 12:18:36 +08:00
zhayujie
b21e945c76 feat: optimize bootstrap flow 2026-03-11 11:27:08 +08:00
zhayujie
a02bf1ea09 Merge pull request #2693 from 6vision/fix/bot-type-and-web-config
fix: rename zhipu bot_type, persist bot_type in web config, fix re.syb escape error
2026-03-11 10:24:19 +08:00
zhayujie
eda82bac92 fix: gemini tool call bug 2026-03-11 02:04:09 +08:00
zhayujie
e8d4f7dc4f fix: remove useless file 2026-03-10 22:56:00 +08:00
6vision
c4a93b7789 fix: rename zhipu bot_type, persist bot_type in web config, fix re.sub escape error
- Rename ZHIPU_AI bot type from glm-4 to zhipu to avoid confusion with model names

- Add bot_type persistence in web config to fix provider dropdown resetting on refresh

- Change OpenAI provider key to chatGPT to match bot_factory routing

- Add DEEPSEEK constant and route it to ChatGPTBot (OpenAI-compatible API)

- Keep backward compatibility for legacy bot_type glm-4 in bot_factory

- Fix re.sub bad escape error on Windows paths by using lambda replacement

- Remove unused pydantic import in minimax_bot.py

Made-with: Cursor
2026-03-10 21:34:24 +08:00
zhayujie
c3f9925097 fix: remove injected max-steps prompt from persisted conversation history 2026-03-10 20:08:59 +08:00
zhayujie
2a0cf7511a Merge pull request #2692 from 6vision/master
update:Adjust bot_type resolution priority in Agent mode
2026-03-10 15:17:22 +08:00
6vision
d0a70d3339 update:Adjust bot_type resolution priority in Agent mode 2026-03-10 15:14:01 +08:00
zhayujie
f37e4675dd Merge pull request #2691 from Weikjssss/fix-bot-type-conf
fix: pass bot_type in agent mode
2026-03-10 15:00:04 +08:00
zhayujie
4e32f67eeb fix: validate tool_call_id pairing #2690 2026-03-10 14:52:07 +08:00
Weikjssss
36d54cab52 fix: pass bot_type in agent mode 2026-03-10 14:28:39 +08:00
zhayujie
9d8df10dcf feat: clarify send tool is local-only 2026-03-10 12:10:10 +08:00
zhayujie
45ea88e070 Merge pull request #2689 from cowagent/fix/openai-compat-complete
fix: complete openai_compat migration across all model bots (openai>=1.0 compatibility)
2026-03-10 10:10:58 +08:00
cowagent
d5d0b947f5 fix: complete openai_compat migration across all model bots
Replace all direct openai.error.* usages with the openai_compat
compatibility layer to support openai>=1.0.

Affected files:
- models/chatgpt/chat_gpt_bot.py: fix isinstance checks (RateLimitError, Timeout, APIError, APIConnectionError)
- models/openai/open_ai_bot.py: replace import + fix isinstance checks
- models/ali/ali_qwen_bot.py: replace import + fix isinstance checks
- models/modelscope/modelscope_bot.py: remove unused openai.error import

The openai_compat layer (models/openai/openai_compat.py) already
handles both openai<1.0 and openai>=1.0 gracefully. This completes
the migration started in the existing PR #2688.
2026-03-10 10:06:04 +08:00
zhayujie
f775f1f11e Merge pull request #2688 from JasonOA888/fix/openai-compat
fix: use openai_compat layer for error handling (openai>=1.0 compatibility)
2026-03-10 10:02:41 +08:00
JasonOA888
f1e888f3de fix: use openai_compat layer for error handling
The code was directly importing openai.error which fails with openai>=1.0.
The project already has an openai_compat.py compatibility layer that handles
both old (<1.0) and new (>=1.0) OpenAI SDK versions.

This commit updates chat_gpt_bot.py to use the compatibility layer.

Related: #2687
2026-03-10 00:33:45 +08:00
zhayujie
71c8436e90 fix: skill download to temp dir 2026-03-09 18:43:28 +08:00
zhayujie
08c69f5e9b fix: clean existing skill directory before remote install to ensure full overwrite 2026-03-09 17:23:09 +08:00
zhayujie
a50fafaca2 refactor: convert image vision from skill to native tool 2026-03-09 16:01:56 +08:00
zhayujie
3c6781d240 refactor: inline skill-creator reference files into SKILL.md 2026-03-09 12:02:52 +08:00
zhayujie
3b8b5625f8 feat: add image vision provider 2026-03-09 11:37:45 +08:00
zhayujie
6be2034110 feat: add fallback embedding provider 2026-03-09 11:03:31 +08:00
zhayujie
924dc79f00 perf: lazy import to avoid 4-10s startup delay 2026-03-09 10:21:58 +08:00
zhayujie
ccb9030d3c refactor: convert web-fetch from skill to native tool 2026-03-09 10:13:48 +08:00
zhayujie
8623287ac1 docs: update memory system docs 2026-03-08 22:06:28 +08:00
zhayujie
022c13f3a4 feat: upgrade memory flush system
- Use LLM to summarize discarded context into concise daily memory entries
- Batch trim to half when exceeding max_turns/max_tokens, reducing flush frequency
- Run summarization asynchronously in background thread, no blocking on replies
- Add daily scheduled flush (23:55) as fallback for low-activity days
- Sync trimmed messages back to agent to keep context state consistent
2026-03-08 21:56:12 +08:00
zhayujie
0687916e7f fix: Safari IME enter key triggering message send
Made-with: Cursor
2026-03-08 13:21:31 +08:00
zhayujie
bb868b83ba feat: add chat history query 2026-03-08 13:03:27 +08:00
zhayujie
24298130b9 fix: minimax tool_id missing 2026-03-06 18:42:03 +08:00
zhayujie
6e5ee92ebd docs: add gpt-5.4 2026-03-06 12:25:50 +08:00
zhayujie
5b91fe04aa fix: send tool process url 2026-03-06 12:22:22 +08:00
zhayujie
1623deb3ee feat: support gpt-5.4 2026-03-06 12:04:40 +08:00
zhayujie
4a16e05b7a fix: rebuild skills when installing 2026-03-05 21:11:34 +08:00
zhayujie
f1c04bc60d feat: improve channel connection stability 2026-03-05 15:55:16 +08:00
zhayujie
84c6f31c76 fix: update agent skill metadata 2026-03-03 18:16:42 +08:00
zhayujie
9d528190bf feat: add skill category 2026-03-03 16:06:37 +08:00
zhayujie
0f23b209ad fix: adjust the context of restart loading 2026-03-03 11:38:14 +08:00
zhayujie
63d9325900 Merge pull request #2683 from pelioo/master
更新.gitignore文件添加python目录忽略规则
2026-03-01 19:41:27 +08:00
peli
f342097f81 Merge remote-tracking branch 'upstream/master' 2026-03-01 00:24:14 +08:00
zhayujie
b4806c4366 fix: model provider config 2026-02-28 18:35:04 +08:00
zhayujie
ff37d8a577 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-02-28 18:10:55 +08:00
zhayujie
a773eb7893 fix: filter history to one user and one assistant per turn 2026-02-28 18:09:02 +08:00
zhayujie
7c67513d24 fix: convert bash-style $VAR to %VAR% on Windows 2026-02-28 18:02:06 +08:00
zhayujie
6ed85029c5 fix: agent skills 2026-02-28 16:46:49 +08:00
zhayujie
e9c57ddf4d fix: adjust default turns 2026-02-28 15:25:20 +08:00
zhayujie
a33ce97ed9 fix: restore only user/assistant text from history, strip tool calls
Made-with: Cursor
2026-02-28 15:14:56 +08:00
zhayujie
b788a3dd4e fix: incomplete historical session messages 2026-02-28 15:03:33 +08:00
zhayujie
fccfa92d7e docs: update channel docs 2026-02-28 14:50:55 +08:00
zhayujie
8705bf0a70 feat: update docs 2026-02-28 10:53:16 +08:00
peli
9318138af7 ```
build(env): 更新.gitignore文件添加python目录忽略规则

在.gitignore文件中新增了python目录的忽略配置,
避免将Python环境相关文件提交到版本控制系统中。
```
2026-02-27 23:49:35 +08:00
zhayujie
269fa7d2d5 feat: 2.0.2 en docs 2026-02-27 18:37:22 +08:00
zhayujie
e99837a8b9 feat: release 2.0.2 2026-02-27 18:04:00 +08:00
zhayujie
553861a2c4 docs: update README.md 2026-02-27 16:57:18 +08:00
zhayujie
628a85d1be docs: update README.md 2026-02-27 16:48:23 +08:00
zhayujie
2cb54514a4 Merge pull request #2681 from zhayujie/feat-docs
feat: docs update
2026-02-27 16:04:17 +08:00
zhayujie
6db22827f2 feat: docs update 2026-02-27 16:03:47 +08:00
zhayujie
4cc6d5426b Merge pull request #2680 from zhayujie/feat-web-config
feat: web console config
2026-02-27 14:40:44 +08:00
zhayujie
7d258b5202 feat(channels): add multi-channel management UI with real-time connect/disconnect
- Web console Channels page: display active channels as config cards, support
  save/connect/disconnect with real-time start/stop of channel processes
- Custom dropdown for channel selection (consistent with model selector style),
  custom confirmation dialog for disconnect
- Fix channel stop: use sys.modules['__main__'] to access live ChannelManager
- Fix web request pending: move stop logic outside lock, set daemon_threads=True
- Fix reconnect: new asyncio event loop per startup, ctypes thread interrupt,
  5s grace period before re-establishing remote connection
- Filter stale offline messages (>60s) pushed after reconnect
2026-02-27 14:39:40 +08:00
zhayujie
c8d19ee0bc Merge pull request #2679 from zhayujie/feat-docs
docs: init docs
2026-02-27 12:14:37 +08:00
zhayujie
d891312032 docs: init docs 2026-02-27 12:10:16 +08:00
zhayujie
5edbf4ce32 feat: model and agent config in web console 2026-02-26 21:01:37 +08:00
zhayujie
3ddbdd713d Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-02-26 18:57:43 +08:00
zhayujie
9ba107b511 Merge branch 'feat-multi-channel' 2026-02-26 18:57:19 +08:00
zhayujie
c9adddb76a fix: pass channel_type correctly in multi-channel mode 2026-02-26 18:57:08 +08:00
zhayujie
f0a12d5ff5 Merge pull request #2678 from zhayujie/feat-multi-channel
feat: support multi-channel
2026-02-26 18:34:48 +08:00
zhayujie
7cce224499 feat: support multi-channel 2026-02-26 18:34:08 +08:00
zhayujie
97397ca585 Merge pull request #2674 from haosenwang1018/fix/bare-excepts
fix: replace 29 bare except clauses with except Exception
2026-02-26 12:11:49 +08:00
zhayujie
f2fbc602a8 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-02-26 10:45:01 +08:00
zhayujie
925d728a86 fix: replace upsert syntax to support SQLite lower version 2026-02-26 10:44:04 +08:00
zhayujie
f5f229871b Merge pull request #2676 from zhayujie/feat-multi-channel
feat: improve web console and conversation store
2026-02-26 10:37:03 +08:00
zhayujie
9917552b4b fix: improve web UI stability and conversation history restore
- Fix dark mode FOUC: apply theme in <head> before first paint, defer
  transition-colors to post-init to avoid animated flash on load
- Fix Safari IME Enter bug: defer compositionend reset via setTimeout(0)
- Fix history scroll: use requestAnimationFrame before scrollChatToBottom
- Limit restore turns to min(6, max_turns//3) on restart
- Fix load_messages cutoff to start at turn boundary, preventing orphaned
  tool_use/tool_result pairs from being sent to the LLM
- Merge all assistant messages within one user turn into a single bubble;
  render tool_calls in history using same CSS as live SSE view
- Handle empty choices list in stream chunks
2026-02-26 10:35:20 +08:00
haosenwang1018
adca89b973 fix: replace bare except clauses with except Exception
Bare `except:` catches BaseException including KeyboardInterrupt and
SystemExit. Replaced 29 instances with `except Exception:`.
2026-02-25 11:49:19 +00:00
zhayujie
29bfbecdc9 feat: persistent storage of conversation history 2026-02-25 18:01:39 +08:00
zhayujie
1a7a8c98d9 docs: add scam warning disclaimer 2026-02-25 01:34:16 +08:00
zhayujie
cddb38ac3d Merge pull request #2673 from zhayujie/feat-web-console
feat: web console
2026-02-24 00:06:29 +08:00
zhayujie
394853c0fb feat: web console module display 2026-02-24 00:04:17 +08:00
zhayujie
c0702c8b36 feat: web channel stream chat 2026-02-23 22:19:50 +08:00
zhayujie
d610608391 feat: add cloud host config 2026-02-23 15:06:31 +08:00
zhayujie
9082eec91d feat: dark mode is used by default 2026-02-23 14:57:02 +08:00
zhayujie
f1a1413b5f feat: web console upgrade 2026-02-21 17:56:31 +08:00
zhayujie
c1e7f9af9b Merge pull request #2672 from zhayujie/feat-config-update
feat: cloud config update
2026-02-21 11:34:05 +08:00
zhayujie
1c71c4e38b feat: agent chat service 2026-02-21 00:39:36 +08:00
zhayujie
5e3eccb3f6 feat: support memory service 2026-02-20 23:44:05 +08:00
zhayujie
e1dc037eb9 feat: cloud skills manage 2026-02-20 23:23:04 +08:00
zhayujie
97e9b4c801 Merge branch 'master' into feat-config-update 2026-02-20 18:58:21 +08:00
zhayujie
52d7cad735 feat: support gemini-3.1-pro-preview and claude-4.6-sonnet 2026-02-20 12:14:59 +08:00
zhayujie
c0b1d270ba Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-02-19 14:18:39 +08:00
zhayujie
e59a2892e4 feat: support qwen3.5-plus 2026-02-19 14:18:16 +08:00
zhayujie
5fa0376a49 Merge pull request #2670 from SgtPepper114/fix/gemini-dingtalk-image-inline
fix(gemini): 修复钉钉图片标记未转多模态导致的识图失效
2026-02-19 13:57:04 +08:00
SgtPepper114
05a33042c8 fix(gemini): support dingtalk image markers as multimodal input
- parse [图片: path] markers in text and convert to Gemini inlineData parts

- unify reply path via call_with_tools to reuse multimodal conversion

- keep legacy safety behavior (BLOCK_NONE) and restore safety ratings logging on empty response

- add multimodal request image-part count log for debugging
2026-02-16 13:26:57 +00:00
zhayujie
ce58f23cbc feat: dashscope model name 2026-02-16 20:11:38 +08:00
zhayujie
b6fc9fa370 fix: run script dependency issues 2026-02-15 00:02:50 +08:00
zhayujie
00ae38faae docs: update models in README 2026-02-14 17:36:36 +08:00
zhayujie
ab28ee58ab feat: add doubao-2.0-code model and update README 2026-02-14 16:49:44 +08:00
zhayujie
48db538a2e feat: support Minimax-M2.5, glm-5, kimi-k2.5 2026-02-14 15:27:44 +08:00
zhayujie
46945942e1 feat: support channel start in sub thread 2026-02-13 12:38:52 +08:00
zhayujie
a24b26a1ef Merge pull request #2667 from cowagent/fix-wechatcom-image-support
fix: 支持企业微信图片消息识别功能
2026-02-12 16:44:18 +08:00
zhayujie
6f8421cdd5 fix: 支持企业微信图片消息识别功能
- 在 ChatGPTBot 中添加 ContextType.IMAGE 处理分支
- 新增 reply_image() 方法,支持 OpenAI Vision API
- 自动 Base64 编码图片并检测格式
- 自动清理临时文件

修复 #2625
2026-02-12 12:00:24 +08:00
zhayujie
284cd9bca9 Merge pull request #2666 from cowagent/fix-model-type-validation
fix: handle non-string model_type to prevent AttributeError
2026-02-10 11:31:45 +08:00
cowagent
23fd6b8d2b fix: handle non-string model_type to prevent AttributeError
When numeric model names (e.g., '1') are used with vLLM and configured
in YAML without quotes, they are parsed as integers. This causes
AttributeError when calling startswith() method.

Changes:
- Add type checking for model_type
- Convert non-string model_type to string with warning log
- Prevents crash when using custom numeric model names

Fixes #2664
2026-02-10 11:07:10 +08:00
zhayujie
4f0ea5d756 feat: make web search a built-in tool 2026-02-09 11:37:11 +08:00
zhayujie
6c218331b1 fix: improve skill system prompts and simplify tool descriptions
- Simplify skill-creator installation flow
- Refine skill selection prompt for better matching
- Add parameter alias and env variable hints for tools
- Skip linkai-agent when unconfigured
- Create skills/ dir in workspace on init
2026-02-08 18:59:59 +08:00
zhayujie
cea7fb7490 fix: add intelligent context cleanup #2663 2026-02-07 20:42:41 +08:00
zhayujie
8acf2dbdfe fix: chat context overflow #2663 2026-02-07 20:36:24 +08:00
zhayujie
0542700f90 fix: issues with empty tool calls and handling excessively long tool results 2026-02-07 20:25:05 +08:00
zhayujie
5264f7ce18 fix: getuid not found in windows 2026-02-07 11:17:58 +08:00
zhayujie
051ffd78a3 fix: windows path and encoding adaptation 2026-02-06 18:37:05 +08:00
zhayujie
bea95d4fae Merge pull request #2661 from cowagent/feat-add-claude-opus-4-6
feat: 添加 Claude Opus 4.6 模型支持
2026-02-06 15:09:49 +08:00
cowagent
fdf7bc312f feat: 添加 Claude Opus 4.6 模型支持
- 在 common/const.py 中添加 CLAUDE_4_6_OPUS 常量
- 将 claude-opus-4-6 添加到 MODEL_LIST
- 在 README.md 中更新 Agent 推荐模型列表
- 在 Claude 配置说明中添加 claude-opus-4-6 支持

Claude Opus 4.6 是 Anthropic 于 2026年2月5日发布的最新模型,
具有更强的规划能力和代码能力,适合作为 Agent 推荐模型。
2026-02-06 15:07:43 +08:00
vision
5b094e1097 Merge pull request #2660 from cowagent/fix-zhipuai-api-base-support
fix: 支持智谱AI自定义API base URL配置
2026-02-05 19:18:49 +08:00
cowagent
9ad3968084 fix: 支持智谱AI自定义API base URL配置
- 修复 ZhipuAiClient 初始化时未传入 base_url 参数的问题
- 使配置文件中的 zhipu_ai_api_base 配置项生效
- 支持智谱国际版(z.ai)等自定义API端点
- 同时修复对话和图片生成功能
- 添加日志输出便于确认使用的API地址

Fixes #2659
2026-02-05 19:06:46 +08:00
zhayujie
3958b6aae1 Merge pull request #2657 from cowagent/fix-missing-runtime-info-parameter
fix: 补充缺失的 runtime_info 参数传递
2026-02-04 22:51:53 +08:00
cowagent
eaa413caf0 fix: 补充缺失的 runtime_info 参数传递
问题:
PR #2655 已合并,但遗漏了关键的参数传递环节。runtime_info 在 agent_initializer.py 中创建并传递给 create_agent(),但 agent_bridge.py 的 create_agent() 方法中没有将其传递给 Agent 实例,导致动态时间更新功能无法生效。

影响:
- Agent 实例的 self.runtime_info 为 None
- get_full_system_prompt() 无法检测到动态时间函数
- 时间戳仍然是静态的,不会实时更新

修复:
在 agent_bridge.py 第 236 行添加:
runtime_info=kwargs.get("runtime_info")

这确保了完整的参数传递链路:
agent_initializer → agent_bridge.create_agent → Agent.__init__

---

*来自 [CowAgent](https://github.com/zhayujie/chatgpt-on-wechat) 项目的 AI Agent*
2026-02-04 22:49:54 +08:00
zhayujie
9095225b5b Merge pull request #2656 from 6vision/master
Update: improve script interaction and configuration
2026-02-04 22:46:02 +08:00
zhayujie
c529f86dbc Merge pull request #2655 from cowagent/fix-runtime-timestamp-update
fix: 动态更新系统提示词中的运行时信息(时间戳)
2026-02-04 22:38:51 +08:00
cowagent
e4fcfa356a refactor: 改用动态函数实现运行时信息更新(更健壮的方案)
改进点:
1. builder.py: _build_runtime_section() 支持 callable 动态时间函数
2. agent_initializer.py: 传入 get_current_time 函数而非静态时间值
3. agent.py: _rebuild_runtime_section() 动态调用时间函数并重建该部分

优势:
- 解耦模板:不依赖具体的提示词格式
- 健壮性:提示词模板改变不会导致功能失效
- 向后兼容:保留对静态时间的支持
- 性能优化:只在需要时才计算时间

相比之前的正则匹配方案,这个方案更加优雅和可维护。
2026-02-04 22:37:19 +08:00
vision
8218cff7c1 Merge branch 'zhayujie:master' into master 2026-02-04 22:32:20 +08:00
6vision
6949bbcf39 update: Improve script interaction and configuration 2026-02-04 22:31:40 +08:00
cowagent
480c60c0a7 fix: 动态更新系统提示词中的运行时信息(时间戳)
问题:
- system_prompt 在 Agent 初始化时固定,导致模型获取的时间信息过时
- 长时间运行的会话中,模型对时间判断不准确

解决方案:
- 在 get_full_system_prompt() 中添加动态更新逻辑
- 每次获取系统提示词时,使用正则表达式替换运行时信息中的时间戳
- 保持其他运行时信息(模型、工作空间等)不变

测试:
- 创建测试脚本验证时间动态更新功能
- 等待3秒后时间正确更新(22:19:45 -> 22:19:48)
2026-02-04 22:27:24 +08:00
zhayujie
eec10cb5db fix: claude remove toolname 2026-02-04 22:15:10 +08:00
zhayujie
02c83d8689 docs: update agent.md 2026-02-04 21:42:52 +08:00
zhayujie
72b1cacea1 fix: hiding the thought process 2026-02-04 19:36:01 +08:00
zhayujie
c72cda3386 fix: minimax reasoning content optimization 2026-02-04 19:26:36 +08:00
zhayujie
867442155e fix: lark connection issue 2026-02-04 17:05:30 +08:00
zhayujie
229b14b6fc fix: feishu cert error 2026-02-04 16:15:38 +08:00
zhayujie
158c87ab8b fix: openai function call 2026-02-04 15:42:43 +08:00
zhayujie
cb303e6109 fix: add decision round log 2026-02-03 21:27:30 +08:00
saboteur7
a77a8741b5 fix: memory loss issue caused by scheduler 2026-02-03 20:45:22 +08:00
zhayujie
3d63459c25 docs: update README.md 2026-02-03 15:44:00 +08:00
saboteur7
ce63de3c58 feat: release 2.0.0 2026-02-03 14:48:30 +08:00
saboteur7
4b3b1219b5 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-02-03 12:20:04 +08:00
saboteur7
73b069a76c docs: update 2.0 README.md 2026-02-03 12:19:36 +08:00
Saboteur7
101cf8d108 Merge pull request #2653 from 6vision/deploy-script
feat: enhance one-click deployment script with full lifecycle management
2026-02-03 03:18:49 +08:00
saboteur7
2e926dfb6e fix: python 3.8 compatibility issues 2026-02-03 03:17:11 +08:00
saboteur7
501866d12a feat: optimize document and model usage 2026-02-03 02:58:15 +08:00
6vision
39bcb0869f feat: enhance one-click deployment script with full lifecycle management 2026-02-03 02:56:46 +08:00
saboteur7
a7b99cde4e Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-02-03 01:18:17 +08:00
saboteur7
60abcd92a3 feat: update README.md and solving Python compatibility issues 2026-02-03 01:17:25 +08:00
zhayujie
cdd36e7052 docs: update README.md 2026-02-03 00:48:03 +08:00
saboteur7
c6ac175ce4 docs: update README.md 2026-02-03 00:43:42 +08:00
zhayujie
46bcd87c23 feat: support minimax M2 models 2026-02-02 23:36:23 +08:00
zhayujie
ab74be8e33 feat: add qwen models tool call 2026-02-02 23:08:24 +08:00
zhayujie
d8298b3eab fix: support glm-4.7 2026-02-02 22:43:08 +08:00
zhayujie
50e60e6d05 fix: bug fixes 2026-02-02 22:22:10 +08:00
zhayujie
5d02acbf37 config: add config template 2026-02-02 14:25:34 +08:00
zhayujie
8901d91f96 feat: startup log optimization 2026-02-02 12:25:47 +08:00
zhayujie
b55021bb3d feat: system Initialization log 2026-02-02 12:18:57 +08:00
zhayujie
0ef51b85e6 Merge branch 'feat-cow-agent' 2026-02-02 12:03:55 +08:00
zhayujie
c77566cc02 fix: adjust the maximum step size 2026-02-02 12:03:16 +08:00
zhayujie
c1bcedfb51 Merge pull request #2652 from zhayujie/feat-cow-agent
feat: cow super agent
2026-02-02 11:59:45 +08:00
zhayujie
08b592816b Merge pull request #2651 from zhayujie/feat-cow-agent
fix: optimize suggestion words and retries
2026-02-01 14:11:53 +08:00
zhayujie
8ef788e799 Merge pull request #2650 from zhayujie/feat-cow-agent
feat: cow agent
2026-02-01 13:14:00 +08:00
Saboteur7
3ce57ef851 Merge pull request #2648 from zhayujie/feat-cow-agent
feat: cow agent core
2026-01-31 13:14:05 +08:00
341 changed files with 31409 additions and 12175 deletions

View File

@@ -79,8 +79,6 @@ body:
description: |
请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。
options:
- wx(个人微信, itchat)
- wxy(个人微信, wechaty)
- wechatmp(公众号, 订阅号)
- wechatmp_service(公众号, 服务号)
- terminal

11
.gitignore vendored
View File

@@ -3,16 +3,15 @@
.vscode
.venv
.vs
.wechaty/
__pycache__/
venv*
*.pyc
python
config.json
QR.png
nohup.out
tmp
plugins.json
itchat.pkl
*.log
logs/
workspace
@@ -34,7 +33,15 @@ plugins/banwords/lib/__pycache__
!plugins/keyword
!plugins/linkai
!plugins/agent
!plugins/cow_cli
client_config.json
ref/
.cursor/
local/
node_modules/
# cow cli
dist/
build/
*.egg-info/
.cow.pid

798
README.md

File diff suppressed because it is too large Load Diff

3
agent/chat/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from agent.chat.service import ChatService
__all__ = ["ChatService"]

264
agent/chat/service.py Normal file
View File

@@ -0,0 +1,264 @@
"""
ChatService - Wraps the Agent stream execution to produce CHAT protocol chunks.
Translates agent events (message_update, message_end, tool_execution_end, etc.)
into the CHAT socket protocol format (content chunks with segment_id, tool_calls chunks).
"""
import time
from typing import Callable, Optional
from common.log import logger
class ChatService:
"""
High-level service that runs an Agent for a given query and streams
the results as CHAT protocol chunks via a callback.
Usage:
svc = ChatService(agent_bridge)
svc.run(query, session_id, send_chunk_fn)
"""
def __init__(self, agent_bridge):
"""
:param agent_bridge: AgentBridge instance (manages agent lifecycle)
"""
self.agent_bridge = agent_bridge
def run(self, query: str, session_id: str, send_chunk_fn: Callable[[dict], None],
channel_type: str = ""):
"""
Run the agent for *query* and stream results back via *send_chunk_fn*.
The method blocks until the agent finishes. After it returns the SDK
will automatically send the final (streaming=false) message.
:param query: user query text
:param session_id: session identifier for agent isolation
:param send_chunk_fn: callable(chunk_data: dict) to send a streaming chunk
:param channel_type: source channel (e.g. "web", "feishu") for persistence
"""
agent = self.agent_bridge.get_agent(session_id=session_id)
if agent is None:
raise RuntimeError("Failed to initialise agent for the session")
# Pass context metadata to model for downstream API requests
if hasattr(agent, 'model'):
agent.model.channel_type = channel_type or ""
agent.model.session_id = session_id or ""
# State shared between the event callback and this method
state = _StreamState()
def on_event(event: dict):
"""Translate agent events into CHAT protocol chunks."""
event_type = event.get("type")
data = event.get("data", {})
if event_type == "message_update":
# Incremental text delta
delta = data.get("delta", "")
if delta:
send_chunk_fn({
"chunk_type": "content",
"delta": delta,
"segment_id": state.segment_id,
})
elif event_type == "message_end":
# A content segment finished.
tool_calls = data.get("tool_calls", [])
if tool_calls:
# After tool_calls are executed the next content will be
# a new segment; collect tool results until turn_end.
state.pending_tool_results = []
elif event_type == "tool_execution_start":
# Notify the client that a tool is about to run (with its input args)
tool_name = data.get("tool_name", "")
arguments = data.get("arguments", {})
# Cache arguments keyed by tool_call_id so tool_execution_end can include them
tool_call_id = data.get("tool_call_id", tool_name)
state.pending_tool_arguments[tool_call_id] = arguments
send_chunk_fn({
"chunk_type": "tool_start",
"tool": tool_name,
"arguments": arguments,
})
elif event_type == "tool_execution_end":
tool_name = data.get("tool_name", "")
tool_call_id = data.get("tool_call_id", tool_name)
# Retrieve cached arguments from the matching tool_execution_start event
arguments = state.pending_tool_arguments.pop(tool_call_id, data.get("arguments", {}))
result = data.get("result", "")
status = data.get("status", "unknown")
execution_time = data.get("execution_time", 0)
elapsed_str = f"{execution_time:.2f}s"
# Serialise result to string if needed
if not isinstance(result, str):
import json
try:
result = json.dumps(result, ensure_ascii=False)
except Exception:
result = str(result)
tool_info = {
"name": tool_name,
"arguments": arguments,
"result": result,
"status": status,
"elapsed": elapsed_str,
}
if state.pending_tool_results is not None:
state.pending_tool_results.append(tool_info)
elif event_type == "turn_end":
has_tool_calls = data.get("has_tool_calls", False)
if has_tool_calls and state.pending_tool_results:
# Flush collected tool results as a single tool_calls chunk
send_chunk_fn({
"chunk_type": "tool_calls",
"tool_calls": state.pending_tool_results,
})
state.pending_tool_results = None
# Next content belongs to a new segment
state.segment_id += 1
# Run the agent with our event callback ---------------------------
logger.info(f"[ChatService] Starting agent run: session={session_id}, query={query[:80]}")
from config import conf
max_context_turns = conf().get("agent_max_context_turns", 20)
# Get full system prompt with skills
full_system_prompt = agent.get_full_system_prompt()
# Create a copy of messages for this execution
with agent.messages_lock:
messages_copy = agent.messages.copy()
original_length = len(agent.messages)
from agent.protocol.agent_stream import AgentStreamExecutor
executor = AgentStreamExecutor(
agent=agent,
model=agent.model,
system_prompt=full_system_prompt,
tools=agent.tools,
max_turns=agent.max_steps,
on_event=on_event,
messages=messages_copy,
max_context_turns=max_context_turns,
)
try:
response = executor.run_stream(query)
except Exception:
# If executor cleared messages (context overflow), sync back
if len(executor.messages) == 0:
with agent.messages_lock:
agent.messages.clear()
logger.info("[ChatService] Cleared agent message history after executor recovery")
raise
# Sync executor messages back to agent (thread-safe).
# The executor may have trimmed context, making its list shorter than
# original_length. In that case we must replace entirely — just
# appending would leave stale pre-trim messages in agent.messages
# and cause the same trim to fire on every subsequent request.
with agent.messages_lock:
trimmed = len(executor.messages) < original_length
if trimmed:
# Context was trimmed: the executor appended the new user
# query *before* trimming, so the new messages (user +
# assistant + tools) sit at the tail of the trimmed list.
# We cannot simply slice at original_length (it exceeds the
# list length). Instead, count how many messages the
# executor added on top of the post-trim baseline.
#
# Timeline inside executor.run_stream:
# 1. messages had `original_length` items
# 2. append user query → original_length + 1
# 3. _trim_messages() → some smaller number (includes the
# user query because it belongs to the last turn)
# 4. LLM replies / tool calls appended
#
# The user query message is always the first message of the
# last turn (it cannot be trimmed away), so we locate it to
# find where "new" messages begin.
new_start = original_length # fallback
for idx in range(len(executor.messages) - 1, -1, -1):
msg = executor.messages[idx]
if msg.get("role") == "user":
content = msg.get("content", [])
is_user_query = False
if isinstance(content, list):
has_text = any(
isinstance(b, dict) and b.get("type") == "text"
for b in content
)
has_tool_result = any(
isinstance(b, dict) and b.get("type") == "tool_result"
for b in content
)
is_user_query = has_text and not has_tool_result
elif isinstance(content, str):
is_user_query = True
if is_user_query:
new_start = idx
break
new_messages = list(executor.messages[new_start:])
else:
new_messages = list(executor.messages[original_length:])
agent.messages = list(executor.messages)
# Persist new messages to SQLite so they survive restarts and
# can be queried via the HISTORY interface.
if new_messages:
self._persist_messages(session_id, list(new_messages), channel_type)
# Store executor reference for files_to_send access
agent.stream_executor = executor
# Execute post-process tools
agent._execute_post_process_tools()
logger.info(f"[ChatService] Agent run completed: session={session_id}")
@staticmethod
def _persist_messages(session_id: str, new_messages: list, channel_type: str = ""):
try:
from config import conf
if not conf().get("conversation_persistence", True):
return
except Exception:
pass
try:
from agent.memory import get_conversation_store
get_conversation_store().append_messages(
session_id, new_messages, channel_type=channel_type
)
except Exception as e:
logger.warning(
f"[ChatService] Failed to persist messages for session={session_id}: {e}"
)
class _StreamState:
"""Mutable state shared between the event callback and the run method."""
def __init__(self):
self.segment_id: int = 0
# None means we are not accumulating tool results right now.
# A list means we are in the middle of a tool-execution phase.
self.pending_tool_results: Optional[list] = None
# Maps tool_call_id -> arguments captured from tool_execution_start,
# so that tool_execution_end can attach the correct input args.
self.pending_tool_arguments: dict = {}

View File

@@ -1,11 +1,23 @@
"""
Memory module for AgentMesh
Provides long-term memory capabilities with hybrid search (vector + keyword)
Provides both long-term memory (vector/keyword search) and short-term
conversation history persistence (SQLite).
"""
from agent.memory.manager import MemoryManager
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
from agent.memory.embedding import create_embedding_provider
from agent.memory.conversation_store import ConversationStore, get_conversation_store
from agent.memory.summarizer import ensure_daily_memory_file
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config', 'create_embedding_provider']
__all__ = [
'MemoryManager',
'MemoryConfig',
'get_default_memory_config',
'set_global_memory_config',
'create_embedding_provider',
'ConversationStore',
'get_conversation_store',
'ensure_daily_memory_file',
]

View File

@@ -4,6 +4,7 @@ Text chunking utilities for memory
Splits text into chunks with token limits and overlap
"""
from __future__ import annotations
from typing import List, Tuple
from dataclasses import dataclass

View File

@@ -4,18 +4,25 @@ Memory configuration module
Provides global memory configuration with simplified workspace structure
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import Optional, List
from pathlib import Path
def _default_workspace():
"""Get default workspace path with proper Windows support"""
from common.utils import expand_path
return expand_path("~/cow")
@dataclass
class MemoryConfig:
"""Configuration for memory storage and search"""
# Storage paths (default: ~/cow)
workspace_root: str = field(default_factory=lambda: os.path.expanduser("~/cow"))
workspace_root: str = field(default_factory=_default_workspace)
# Embedding config
embedding_provider: str = "openai" # "openai" | "local"
@@ -41,9 +48,6 @@ class MemoryConfig:
enable_auto_sync: bool = True
sync_on_search: bool = True
# Memory flush config (独立于模型 context window)
flush_token_threshold: int = 50000 # 50K tokens 触发 flush
flush_turn_threshold: int = 20 # 20 轮对话触发 flush (用户+AI各一条为一轮)
def get_workspace(self) -> Path:
"""Get workspace root directory"""

View File

@@ -0,0 +1,618 @@
"""
Conversation history persistence using SQLite.
Design:
- sessions table: per-session metadata (channel_type, last_active, msg_count)
- messages table: individual messages stored as JSON, append-only
- Pruning: age-based only (sessions not updated within N days are deleted)
- Thread-safe via a single in-process lock
Storage path: ~/cow/sessions/conversations.db
"""
from __future__ import annotations
import json
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from common.log import logger
# ---------------------------------------------------------------------------
# Schema
# ---------------------------------------------------------------------------
_DDL = """
CREATE TABLE IF NOT EXISTS sessions (
session_id TEXT PRIMARY KEY,
channel_type TEXT NOT NULL DEFAULT '',
created_at INTEGER NOT NULL,
last_active INTEGER NOT NULL,
msg_count INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL,
seq INTEGER NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
created_at INTEGER NOT NULL,
UNIQUE (session_id, seq)
);
CREATE INDEX IF NOT EXISTS idx_messages_session
ON messages (session_id, seq);
CREATE INDEX IF NOT EXISTS idx_sessions_last_active
ON sessions (last_active);
"""
# Migration: add channel_type column to existing databases that predate it.
_MIGRATION_ADD_CHANNEL_TYPE = """
ALTER TABLE sessions ADD COLUMN channel_type TEXT NOT NULL DEFAULT '';
"""
DEFAULT_MAX_AGE_DAYS: int = 30
def _is_visible_user_message(content: Any) -> bool:
"""
Return True when a user-role message represents actual user input
(not an internal tool_result injected by the agent loop).
"""
if isinstance(content, str):
return bool(content.strip())
if isinstance(content, list):
return any(
isinstance(b, dict) and b.get("type") == "text"
for b in content
)
return False
def _extract_display_text(content: Any) -> str:
"""
Extract the human-readable text portion from a message content value.
Returns an empty string for tool_use / tool_result blocks.
"""
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
parts = [
b.get("text", "")
for b in content
if isinstance(b, dict) and b.get("type") == "text"
]
return "\n".join(p for p in parts if p).strip()
return ""
def _extract_tool_calls(content: Any) -> List[Dict[str, Any]]:
"""
Extract tool_use blocks from an assistant message content.
Returns a list of {name, arguments} dicts (result filled in later).
"""
if not isinstance(content, list):
return []
return [
{"id": b.get("id", ""), "name": b.get("name", ""), "arguments": b.get("input", {})}
for b in content
if isinstance(b, dict) and b.get("type") == "tool_use"
]
def _extract_tool_results(content: Any) -> Dict[str, str]:
"""
Extract tool_result blocks from a user message, keyed by tool_use_id.
"""
if not isinstance(content, list):
return {}
results = {}
for b in content:
if not isinstance(b, dict) or b.get("type") != "tool_result":
continue
tool_id = b.get("tool_use_id", "")
result_content = b.get("content", "")
if isinstance(result_content, list):
result_content = "\n".join(
rb.get("text", "") for rb in result_content
if isinstance(rb, dict) and rb.get("type") == "text"
)
results[tool_id] = str(result_content)
return results
def _group_into_display_turns(
rows: List[tuple],
) -> List[Dict[str, Any]]:
"""
Convert raw (role, content_json, created_at) DB rows into display turns.
One display turn = one visible user message + one merged assistant reply.
All intermediate assistant messages (those carrying tool_use) and the final
assistant text reply produced for the same user query are collapsed into a
single assistant turn, exactly matching the live SSE rendering where tools
and the final answer appear inside the same bubble.
Grouping rules:
- A visible user message starts a new group.
- tool_result user messages are internal; their content is attached to the
matching tool_use entry via tool_use_id and they never become own turns.
- All assistant messages within a group are merged:
* tool_use blocks → tool_calls list (result filled from tool_results)
* text blocks → last non-empty text becomes the display content
"""
# ------------------------------------------------------------------ #
# Pass 1: split rows into groups, each starting with a visible user msg
# ------------------------------------------------------------------ #
# group = (user_row | None, [subsequent_rows])
# user_row: (content, created_at)
groups: List[tuple] = []
cur_user: Optional[tuple] = None
cur_rest: List[tuple] = []
started = False
for role, raw_content, created_at in rows:
try:
content = json.loads(raw_content)
except Exception:
content = raw_content
if role == "user" and _is_visible_user_message(content):
if started:
groups.append((cur_user, cur_rest))
cur_user = (content, created_at)
cur_rest = []
started = True
else:
cur_rest.append((role, content, created_at))
if started:
groups.append((cur_user, cur_rest))
# ------------------------------------------------------------------ #
# Pass 2: build display turns from each group
# ------------------------------------------------------------------ #
turns: List[Dict[str, Any]] = []
for user_row, rest in groups:
# User turn
if user_row:
content, created_at = user_row
text = _extract_display_text(content)
if text:
turns.append({"role": "user", "content": text, "created_at": created_at})
# Collect all tool_calls and tool_results from the rest of the group
all_tool_calls: List[Dict[str, Any]] = []
tool_results: Dict[str, str] = {}
final_text = ""
final_ts: Optional[int] = None
for role, content, created_at in rest:
if role == "user":
tool_results.update(_extract_tool_results(content))
elif role == "assistant":
tcs = _extract_tool_calls(content)
all_tool_calls.extend(tcs)
t = _extract_display_text(content)
if t:
final_text = t
final_ts = created_at
# Attach tool results to their matching tool_call entries
for tc in all_tool_calls:
tc["result"] = tool_results.get(tc.get("id", ""), "")
if final_text or all_tool_calls:
turns.append({
"role": "assistant",
"content": final_text,
"tool_calls": all_tool_calls,
"created_at": final_ts or (user_row[1] if user_row else 0),
})
return turns
class ConversationStore:
"""
SQLite-backed store for per-session conversation history.
Usage:
store = ConversationStore(db_path)
store.append_messages("user_123", new_messages, channel_type="feishu")
msgs = store.load_messages("user_123", max_turns=30)
"""
def __init__(self, db_path: Path):
self._db_path = db_path
self._lock = threading.Lock()
self._init_db()
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def load_messages(
self,
session_id: str,
max_turns: int = 30,
) -> List[Dict[str, Any]]:
"""
Load the most recent messages for a session, for injection into the LLM.
ALL message types (user text, assistant tool_use, tool_result) are returned
in their original JSON form so the LLM can reconstruct the full context.
max_turns is a *visible-turn* count: we count only user messages whose
content is actual user text (not tool_result blocks). This prevents
tool-heavy sessions from exhausting the turn budget prematurely.
Args:
session_id: Unique session identifier.
max_turns: Maximum number of visible user-assistant turns to keep.
Returns:
Chronologically ordered list of message dicts (role, content).
"""
with self._lock:
conn = self._connect()
try:
rows = conn.execute(
"""
SELECT seq, role, content
FROM messages
WHERE session_id = ?
ORDER BY seq DESC
""",
(session_id,),
).fetchall()
finally:
conn.close()
if not rows:
return []
# Walk newest-to-oldest counting *visible* user turns (actual user text,
# not tool_result injections). Record the seq of every visible user
# message so we can find a clean cut point later.
visible_turn_seqs: List[int] = [] # newest first
for seq, role, raw_content in rows:
if role != "user":
continue
try:
content = json.loads(raw_content)
except Exception:
content = raw_content
if _is_visible_user_message(content):
visible_turn_seqs.append(seq)
# Determine the seq of the oldest visible user message we want to keep.
# If the total turns fit within max_turns, keep everything.
if len(visible_turn_seqs) <= max_turns:
cutoff_seq = None # keep all
else:
# The Nth visible user message (0-indexed) is the oldest we keep.
cutoff_seq = visible_turn_seqs[max_turns - 1]
# Build result in chronological order, starting from cutoff.
# IMPORTANT: we start exactly at cutoff_seq (the visible user message),
# never mid-group, so tool_use / tool_result pairs are always complete.
result = []
for seq, role, raw_content in reversed(rows):
if cutoff_seq is not None and seq < cutoff_seq:
continue
try:
content = json.loads(raw_content)
except Exception:
content = raw_content
result.append({"role": role, "content": content})
return result
def append_messages(
self,
session_id: str,
messages: List[Dict[str, Any]],
channel_type: str = "",
) -> None:
"""
Append new messages to a session's history.
Seq numbers continue from the session's current maximum, so
concurrent callers on distinct sessions never collide.
Args:
session_id: Unique session identifier.
messages: List of message dicts to append.
channel_type: Source channel (e.g. "feishu", "web", "wechat").
Only written on session creation; ignored on update.
"""
if not messages:
return
now = int(time.time())
with self._lock:
conn = self._connect()
try:
with conn:
# INSERT OR IGNORE creates the row on first visit;
# the UPDATE always refreshes last_active.
# Avoids ON CONFLICT...DO UPDATE (requires SQLite >= 3.24).
conn.execute(
"""
INSERT OR IGNORE INTO sessions
(session_id, channel_type, created_at, last_active, msg_count)
VALUES (?, ?, ?, ?, 0)
""",
(session_id, channel_type, now, now),
)
conn.execute(
"UPDATE sessions SET last_active = ? WHERE session_id = ?",
(now, session_id),
)
# Determine starting seq for the new batch.
row = conn.execute(
"SELECT COALESCE(MAX(seq), -1) FROM messages WHERE session_id = ?",
(session_id,),
).fetchone()
next_seq = row[0] + 1
for msg in messages:
role = msg.get("role", "")
content = json.dumps(
msg.get("content", ""), ensure_ascii=False
)
conn.execute(
"""
INSERT OR IGNORE INTO messages
(session_id, seq, role, content, created_at)
VALUES (?, ?, ?, ?, ?)
""",
(session_id, next_seq, role, content, now),
)
next_seq += 1
conn.execute(
"""
UPDATE sessions
SET msg_count = (
SELECT COUNT(*) FROM messages WHERE session_id = ?
)
WHERE session_id = ?
""",
(session_id, session_id),
)
finally:
conn.close()
def clear_session(self, session_id: str) -> None:
"""Delete all messages and the session record for a given session_id."""
with self._lock:
conn = self._connect()
try:
with conn:
conn.execute(
"DELETE FROM messages WHERE session_id = ?", (session_id,)
)
conn.execute(
"DELETE FROM sessions WHERE session_id = ?", (session_id,)
)
finally:
conn.close()
def cleanup_old_sessions(self, max_age_days: Optional[int] = None) -> int:
"""
Delete sessions that have not been active within max_age_days.
Args:
max_age_days: Override the default retention period.
Returns:
Number of sessions deleted.
"""
try:
from config import conf
max_age = max_age_days or conf().get(
"conversation_max_age_days", DEFAULT_MAX_AGE_DAYS
)
except Exception:
max_age = max_age_days or DEFAULT_MAX_AGE_DAYS
cutoff = int(time.time()) - max_age * 86400
deleted = 0
with self._lock:
conn = self._connect()
try:
with conn:
stale = conn.execute(
"SELECT session_id FROM sessions WHERE last_active < ?",
(cutoff,),
).fetchall()
for (sid,) in stale:
conn.execute(
"DELETE FROM messages WHERE session_id = ?", (sid,)
)
conn.execute(
"DELETE FROM sessions WHERE session_id = ?", (sid,)
)
deleted += 1
finally:
conn.close()
if deleted:
logger.info(f"[ConversationStore] Pruned {deleted} expired sessions")
return deleted
def load_history_page(
self,
session_id: str,
page: int = 1,
page_size: int = 20,
) -> Dict[str, Any]:
"""
Load a page of conversation history for UI display, grouped into turns.
Each "turn" maps to one of:
- A user message (role="user", content=str)
- An assistant message (role="assistant", content=str,
tool_calls=[{name, arguments, result}] when tools were used)
Internal tool_result user messages are merged into the preceding
assistant entry's tool_calls list and never appear as standalone items.
Pages are numbered from 1 (most recent). Messages within a page are
returned in chronological order.
Returns:
{
"messages": [
{
"role": "user" | "assistant",
"content": str,
"tool_calls": [...], # assistant only, may be []
"created_at": int,
},
...
],
"total": <visible turn count>,
"page": <current page>,
"page_size": <page_size>,
"has_more": bool,
}
"""
page = max(1, page)
with self._lock:
conn = self._connect()
try:
rows = conn.execute(
"""
SELECT role, content, created_at
FROM messages
WHERE session_id = ?
ORDER BY seq ASC
""",
(session_id,),
).fetchall()
finally:
conn.close()
visible = _group_into_display_turns(rows)
total = len(visible)
offset = (page - 1) * page_size
page_items = list(reversed(visible))[offset: offset + page_size]
page_items = list(reversed(page_items))
return {
"messages": page_items,
"total": total,
"page": page,
"page_size": page_size,
"has_more": offset + page_size < total,
}
def get_stats(self) -> Dict[str, Any]:
"""Return basic stats keyed by channel_type, for monitoring."""
with self._lock:
conn = self._connect()
try:
total_sessions = conn.execute(
"SELECT COUNT(*) FROM sessions"
).fetchone()[0]
total_messages = conn.execute(
"SELECT COUNT(*) FROM messages"
).fetchone()[0]
by_channel = conn.execute(
"""
SELECT channel_type, COUNT(*) as cnt
FROM sessions
GROUP BY channel_type
ORDER BY cnt DESC
"""
).fetchall()
return {
"total_sessions": total_sessions,
"total_messages": total_messages,
"by_channel": {row[0] or "unknown": row[1] for row in by_channel},
}
finally:
conn.close()
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _init_db(self) -> None:
self._db_path.parent.mkdir(parents=True, exist_ok=True)
conn = self._connect()
try:
conn.executescript(_DDL)
conn.commit()
self._migrate(conn)
finally:
conn.close()
def _migrate(self, conn: sqlite3.Connection) -> None:
"""Apply incremental schema migrations on existing databases."""
cols = {
row[1]
for row in conn.execute("PRAGMA table_info(sessions)").fetchall()
}
if "channel_type" not in cols:
try:
conn.execute(_MIGRATION_ADD_CHANNEL_TYPE)
conn.commit()
logger.info("[ConversationStore] Migrated: added channel_type column")
except Exception as e:
logger.warning(f"[ConversationStore] Migration failed: {e}")
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(str(self._db_path), timeout=10)
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
return conn
# ---------------------------------------------------------------------------
# Singleton
# ---------------------------------------------------------------------------
_store_instance: Optional[ConversationStore] = None
_store_lock = threading.Lock()
def get_conversation_store() -> ConversationStore:
"""
Return the process-wide ConversationStore singleton.
Reuses the long-term memory database so the project stays with a single
SQLite file: ~/cow/memory/long-term/index.db
The conversation tables (sessions / messages) are separate from the
memory tables (memory_chunks / file_metadata) — no conflicts.
"""
global _store_instance
if _store_instance is not None:
return _store_instance
with _store_lock:
if _store_instance is not None:
return _store_instance
try:
from agent.memory.config import get_default_memory_config
db_path = get_default_memory_config().get_db_path()
except Exception:
from common.utils import expand_path
db_path = Path(expand_path("~/cow")) / "memory" / "long-term" / "index.db"
_store_instance = ConversationStore(db_path)
logger.debug(f"[ConversationStore] Using shared DB at: {db_path}")
return _store_instance

View File

@@ -32,21 +32,25 @@ class EmbeddingProvider(ABC):
class OpenAIEmbeddingProvider(EmbeddingProvider):
"""OpenAI embedding provider using REST API"""
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None,
api_base: Optional[str] = None, extra_headers: Optional[dict] = None):
"""
Initialize OpenAI embedding provider
Args:
model: Model name (text-embedding-3-small or text-embedding-3-large)
api_key: OpenAI API key
api_base: Optional API base URL
extra_headers: Optional extra headers to include in API requests
"""
self.model = model
self.api_key = api_key
self.api_base = api_base or "https://api.openai.com/v1"
self.extra_headers = extra_headers or {}
if not self.api_key:
raise ValueError("OpenAI API key is required")
# Validate API key
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
raise ValueError("OpenAI API key is not configured. Please set 'open_ai_api_key' in config.json")
# Set dimensions based on model
self._dimensions = 1536 if "small" in model else 3072
@@ -58,16 +62,29 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
url = f"{self.api_base}/embeddings"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
"Authorization": f"Bearer {self.api_key}",
**self.extra_headers,
}
data = {
"input": input_data,
"model": self.model
}
response = requests.post(url, headers=headers, json=data, timeout=30)
response.raise_for_status()
return response.json()
try:
response = requests.post(url, headers=headers, json=data, timeout=5)
response.raise_for_status()
return response.json()
except requests.exceptions.ConnectionError as e:
raise ConnectionError(f"Failed to connect to OpenAI API at {url}. Please check your network connection and api_base configuration. Error: {str(e)}")
except requests.exceptions.Timeout as e:
raise TimeoutError(f"OpenAI API request timed out after 10s. Please check your network connection. Error: {str(e)}")
except requests.exceptions.HTTPError as e:
if e.response.status_code == 401:
raise ValueError(f"Invalid OpenAI API key. Please check your 'open_ai_api_key' in config.json")
elif e.response.status_code == 429:
raise ValueError(f"OpenAI API rate limit exceeded. Please try again later.")
else:
raise ValueError(f"OpenAI API request failed: {e.response.status_code} - {e.response.text}")
def embed(self, text: str) -> List[float]:
"""Generate embedding for text"""
@@ -121,28 +138,30 @@ def create_embedding_provider(
provider: str = "openai",
model: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None
api_base: Optional[str] = None,
extra_headers: Optional[dict] = None
) -> EmbeddingProvider:
"""
Factory function to create embedding provider
Only supports OpenAI embedding via REST API.
Supports "openai" and "linkai" providers (both use OpenAI-compatible REST API).
If initialization fails, caller should fall back to keyword-only search.
Args:
provider: Provider name (only "openai" is supported)
provider: Provider name ("openai" or "linkai")
model: Model name (default: text-embedding-3-small)
api_key: OpenAI API key (required)
api_base: API base URL (default: https://api.openai.com/v1)
api_key: API key (required)
api_base: API base URL
extra_headers: Optional extra headers to include in API requests
Returns:
EmbeddingProvider instance
Raises:
ValueError: If provider is not "openai" or api_key is missing
ValueError: If provider is unsupported or api_key is missing
"""
if provider != "openai":
raise ValueError(f"Only 'openai' provider is supported, got: {provider}")
if provider not in ("openai", "linkai"):
raise ValueError(f"Unsupported embedding provider: {provider}. Use 'openai' or 'linkai'.")
model = model or "text-embedding-3-small"
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base, extra_headers=extra_headers)

View File

@@ -50,28 +50,48 @@ class MemoryManager:
overlap_tokens=self.config.chunk_overlap_tokens
)
# Initialize embedding provider (optional)
# Initialize embedding provider (optional, prefer OpenAI, fallback to LinkAI)
self.embedding_provider = None
if embedding_provider:
self.embedding_provider = embedding_provider
else:
# Try to create embedding provider, but allow failure
# Try OpenAI first
try:
# Get API key from environment or config
api_key = os.environ.get('OPENAI_API_KEY')
api_base = os.environ.get('OPENAI_API_BASE')
self.embedding_provider = create_embedding_provider(
provider=self.config.embedding_provider,
model=self.config.embedding_model,
api_key=api_key,
api_base=api_base
)
if api_key:
self.embedding_provider = create_embedding_provider(
provider="openai",
model=self.config.embedding_model,
api_key=api_key,
api_base=api_base
)
except Exception as e:
# Embedding provider failed, but that's OK
# We can still use keyword search and file operations
from common.log import logger
logger.warning(f"[MemoryManager] Embedding provider initialization failed: {e}")
logger.warning(f"[MemoryManager] OpenAI embedding failed: {e}")
# Fallback to LinkAI
if self.embedding_provider is None:
try:
linkai_key = os.environ.get('LINKAI_API_KEY')
linkai_base = os.environ.get('LINKAI_API_BASE', 'https://api.link-ai.tech')
if linkai_key:
from common.utils import get_cloud_headers
cloud_headers = get_cloud_headers(linkai_key)
cloud_headers.pop("Authorization", None)
self.embedding_provider = create_embedding_provider(
provider="linkai",
model=self.config.embedding_model,
api_key=linkai_key,
api_base=f"{linkai_base}/v1",
extra_headers=cloud_headers,
)
except Exception as e:
from common.log import logger
logger.warning(f"[MemoryManager] LinkAI embedding failed: {e}")
if self.embedding_provider is None:
from common.log import logger
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
# Initialize memory flush manager
@@ -304,7 +324,7 @@ class MemoryManager:
):
"""Sync a single file"""
# Compute file hash
content = file_path.read_text()
content = file_path.read_text(encoding='utf-8')
file_hash = MemoryStorage.compute_hash(content)
# Get relative path
@@ -363,182 +383,35 @@ class MemoryManager:
size=stat.st_size
)
def should_flush_memory(
def flush_memory(
self,
current_tokens: int = 0
) -> bool:
"""
Check if memory flush should be triggered
独立的 flush 触发机制,不依赖模型 context window。
使用配置中的阈值: flush_token_threshold 和 flush_turn_threshold
Args:
current_tokens: Current session token count
Returns:
True if memory flush should run
"""
return self.flush_manager.should_flush(
current_tokens=current_tokens,
token_threshold=self.config.flush_token_threshold,
turn_threshold=self.config.flush_turn_threshold
)
def increment_turn(self):
"""增加对话轮数计数(每次用户消息+AI回复算一轮"""
self.flush_manager.increment_turn()
async def execute_memory_flush(
self,
agent_executor,
current_tokens: int,
messages: list,
user_id: Optional[str] = None,
**executor_kwargs
reason: str = "threshold",
max_messages: int = 10,
) -> bool:
"""
Execute memory flush before compaction
This runs a silent agent turn to write durable memories to disk.
Similar to clawdbot's pre-compaction memory flush.
Flush conversation summary to daily memory file.
Args:
agent_executor: Async function to execute agent with prompt
current_tokens: Current session token count
messages: Conversation message list
user_id: Optional user ID
**executor_kwargs: Additional kwargs for agent executor
reason: "threshold" | "overflow" | "daily_summary"
max_messages: Max recent messages to include (0 = all)
Returns:
True if flush completed successfully
Example:
>>> async def run_agent(prompt, system_prompt, silent=False):
... # Your agent execution logic
... pass
>>>
>>> if manager.should_flush_memory(current_tokens=100000):
... await manager.execute_memory_flush(
... agent_executor=run_agent,
... current_tokens=100000
... )
True if content was written
"""
success = await self.flush_manager.execute_flush(
agent_executor=agent_executor,
current_tokens=current_tokens,
success = self.flush_manager.flush_from_messages(
messages=messages,
user_id=user_id,
**executor_kwargs
reason=reason,
max_messages=max_messages,
)
if success:
# Mark dirty so next search will sync the new memories
self._dirty = True
return success
def build_memory_guidance(self, lang: str = "zh", include_context: bool = True) -> str:
"""
Build natural memory guidance for agent system prompt
Following clawdbot's approach:
1. Load MEMORY.md as bootstrap context (blends into background)
2. Load daily files on-demand via memory_search tool
3. Agent should NOT proactively mention memories unless user asks
Args:
lang: Language for guidance ("en" or "zh")
include_context: Whether to include bootstrap memory context (default: True)
MEMORY.md is loaded as background context (like clawdbot)
Daily files are accessed via memory_search tool
Returns:
Memory guidance text (and optionally context) for system prompt
"""
today_file = self.flush_manager.get_today_memory_file().name
if lang == "zh":
guidance = f"""## 记忆系统
**背景知识**: 下方包含核心长期记忆,可直接使用。需要查找历史时,用 memory_search 搜索(搜索一次即可,不要重复)。
**存储记忆**: 当用户分享重要信息时(偏好、决策、事实等),主动用 write 工具存储:
- 长期信息 → MEMORY.md
- 当天笔记 → memory/{today_file}
- 静默存储,仅在明确要求时确认
**使用原则**: 自然使用记忆,就像你本来就知道。不需要生硬地提起或列举记忆,除非用户提到。"""
else:
guidance = f"""## Memory System
**Background Knowledge**: Core long-term memories below - use directly. For history, use memory_search once (don't repeat).
**Store Memories**: When user shares important info (preferences, decisions, facts), proactively write:
- Durable info → MEMORY.md
- Daily notes → memory/{today_file}
- Store silently; confirm only when explicitly requested
**Usage**: Use memories naturally as if you always knew. Don't mention or list unless user explicitly asks."""
if include_context:
# Load bootstrap context (MEMORY.md only, like clawdbot)
bootstrap_context = self.load_bootstrap_memories()
if bootstrap_context:
guidance += f"\n\n## Background Context\n\n{bootstrap_context}"
return guidance
def load_bootstrap_memories(self, user_id: Optional[str] = None) -> str:
"""
Load bootstrap memory files for session start
Following clawdbot's design:
- Only loads MEMORY.md from workspace root (long-term curated memory)
- Daily files (memory/YYYY-MM-DD.md) are accessed via memory_search tool, not bootstrap
- User-specific MEMORY.md is also loaded if user_id provided
Returns memory content WITHOUT obvious headers so it blends naturally
into the context as background knowledge.
Args:
user_id: Optional user ID for user-specific memories
Returns:
Memory content to inject into system prompt (blends naturally as background context)
"""
workspace_dir = self.config.get_workspace()
memory_dir = self.config.get_memory_dir()
sections = []
# 1. Load MEMORY.md from workspace root (long-term curated memory)
# Following clawdbot: only MEMORY.md is bootstrap, daily files use memory_search
memory_file = Path(workspace_dir) / "MEMORY.md"
if memory_file.exists():
try:
content = memory_file.read_text(encoding='utf-8').strip()
if content:
sections.append(content)
except Exception as e:
print(f"Warning: Failed to read MEMORY.md: {e}")
# 2. Load user-specific MEMORY.md if user_id provided
if user_id:
user_memory_dir = memory_dir / "users" / user_id
user_memory_file = user_memory_dir / "MEMORY.md"
if user_memory_file.exists():
try:
content = user_memory_file.read_text(encoding='utf-8').strip()
if content:
sections.append(content)
except Exception as e:
print(f"Warning: Failed to read user memory: {e}")
if not sections:
return ""
# Join sections without obvious headers - let memories blend naturally
# This makes the agent feel like it "just knows" rather than "checking memory files"
return "\n\n".join(sections)
def get_status(self) -> Dict[str, Any]:
"""Get memory status"""
stats = self.storage.get_stats()
@@ -568,6 +441,37 @@ class MemoryManager:
content = f"{path}:{start_line}:{end_line}"
return hashlib.md5(content.encode('utf-8')).hexdigest()
@staticmethod
def _compute_temporal_decay(path: str, half_life_days: float = 30.0) -> float:
"""
Compute temporal decay multiplier for dated memory files.
Inspired by OpenClaw's temporal-decay: exponential decay based on file date.
MEMORY.md and non-dated files are "evergreen" (no decay, multiplier=1.0).
Daily files like memory/2025-03-01.md decay based on age.
Formula: multiplier = exp(-ln2/half_life * age_in_days)
"""
import re
import math
match = re.search(r'(\d{4})-(\d{2})-(\d{2})\.md$', path)
if not match:
return 1.0 # evergreen: MEMORY.md, non-dated files
try:
file_date = datetime(
int(match.group(1)), int(match.group(2)), int(match.group(3))
)
age_days = (datetime.now() - file_date).days
if age_days <= 0:
return 1.0
decay_lambda = math.log(2) / half_life_days
return math.exp(-decay_lambda * age_days)
except (ValueError, OverflowError):
return 1.0
def _merge_results(
self,
vector_results: List[SearchResult],
@@ -575,8 +479,7 @@ class MemoryManager:
vector_weight: float,
keyword_weight: float
) -> List[SearchResult]:
"""Merge vector and keyword search results"""
# Create a map by (path, start_line, end_line)
"""Merge vector and keyword search results with temporal decay for dated files"""
merged_map = {}
for result in vector_results:
@@ -598,7 +501,6 @@ class MemoryManager:
'keyword_score': result.score
}
# Calculate combined scores
merged_results = []
for entry in merged_map.values():
combined_score = (
@@ -606,7 +508,11 @@ class MemoryManager:
keyword_weight * entry['keyword_score']
)
# Apply temporal decay for dated memory files
result = entry['result']
decay = self._compute_temporal_decay(result.path)
combined_score *= decay
merged_results.append(SearchResult(
path=result.path,
start_line=result.start_line,
@@ -617,6 +523,5 @@ class MemoryManager:
user_id=result.user_id
))
# Sort by score
merged_results.sort(key=lambda r: r.score, reverse=True)
return merged_results

167
agent/memory/service.py Normal file
View File

@@ -0,0 +1,167 @@
"""
Memory service for handling memory query operations via cloud protocol.
Provides a unified interface for listing and reading memory files,
callable from the cloud client (LinkAI) or a future web console.
Memory file layout (under workspace_root):
MEMORY.md -> type: global
memory/2026-02-20.md -> type: daily
"""
import os
from datetime import datetime
from typing import Dict, List, Optional
from pathlib import Path
from common.log import logger
class MemoryService:
"""
High-level service for memory file queries.
Operates directly on the filesystem — no MemoryManager dependency.
"""
def __init__(self, workspace_root: str):
"""
:param workspace_root: Workspace root directory (e.g. ~/cow)
"""
self.workspace_root = workspace_root
self.memory_dir = os.path.join(workspace_root, "memory")
# ------------------------------------------------------------------
# list — paginated file metadata
# ------------------------------------------------------------------
def list_files(self, page: int = 1, page_size: int = 20) -> dict:
"""
List all memory files with metadata (without content).
Returns::
{
"page": 1,
"page_size": 20,
"total": 15,
"list": [
{"filename": "MEMORY.md", "type": "global", "size": 2048, "updated_at": "2026-02-20 10:00:00"},
{"filename": "2026-02-20.md", "type": "daily", "size": 512, "updated_at": "2026-02-20 09:30:00"},
...
]
}
"""
files: List[dict] = []
# 1. Global memory — MEMORY.md in workspace root
global_path = os.path.join(self.workspace_root, "MEMORY.md")
if os.path.isfile(global_path):
files.append(self._file_info(global_path, "MEMORY.md", "global"))
# 2. Daily memory files — memory/*.md (sorted newest first)
if os.path.isdir(self.memory_dir):
daily_files = []
for name in os.listdir(self.memory_dir):
full = os.path.join(self.memory_dir, name)
if os.path.isfile(full) and name.endswith(".md"):
daily_files.append((name, full))
# Sort by filename descending (newest date first)
daily_files.sort(key=lambda x: x[0], reverse=True)
for name, full in daily_files:
files.append(self._file_info(full, name, "daily"))
total = len(files)
# Paginate
start = (page - 1) * page_size
end = start + page_size
page_items = files[start:end]
return {
"page": page,
"page_size": page_size,
"total": total,
"list": page_items,
}
# ------------------------------------------------------------------
# content — read a single file
# ------------------------------------------------------------------
def get_content(self, filename: str) -> dict:
"""
Read the full content of a memory file.
:param filename: File name, e.g. ``MEMORY.md`` or ``2026-02-20.md``
:return: dict with ``filename`` and ``content``
:raises FileNotFoundError: if the file does not exist
"""
path = self._resolve_path(filename)
if not os.path.isfile(path):
raise FileNotFoundError(f"Memory file not found: {filename}")
with open(path, "r", encoding="utf-8") as f:
content = f.read()
return {
"filename": filename,
"content": content,
}
# ------------------------------------------------------------------
# dispatch — single entry point for protocol messages
# ------------------------------------------------------------------
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
"""
Dispatch a memory management action.
:param action: ``list`` or ``content``
:param payload: action-specific payload
:return: protocol-compatible response dict
"""
payload = payload or {}
try:
if action == "list":
page = payload.get("page", 1)
page_size = payload.get("page_size", 20)
result_payload = self.list_files(page=page, page_size=page_size)
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
elif action == "content":
filename = payload.get("filename")
if not filename:
return {"action": action, "code": 400, "message": "filename is required", "payload": None}
result_payload = self.get_content(filename)
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
else:
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
except FileNotFoundError as e:
return {"action": action, "code": 404, "message": str(e), "payload": None}
except Exception as e:
logger.error(f"[MemoryService] dispatch error: action={action}, error={e}")
return {"action": action, "code": 500, "message": str(e), "payload": None}
# ------------------------------------------------------------------
# internal helpers
# ------------------------------------------------------------------
def _resolve_path(self, filename: str) -> str:
"""
Resolve a filename to its absolute path.
- ``MEMORY.md`` → ``{workspace_root}/MEMORY.md``
- ``2026-02-20.md`` → ``{workspace_root}/memory/2026-02-20.md``
"""
if filename == "MEMORY.md":
return os.path.join(self.workspace_root, filename)
return os.path.join(self.memory_dir, filename)
@staticmethod
def _file_info(path: str, filename: str, file_type: str) -> dict:
"""Build a file metadata dict."""
stat = os.stat(path)
updated_at = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
return {
"filename": filename,
"type": file_type,
"size": stat.st_size,
"updated_at": updated_at,
}

View File

@@ -4,6 +4,7 @@ Storage layer for memory using SQLite + FTS5
Provides vector and keyword search capabilities
"""
from __future__ import annotations
import sqlite3
import json
import hashlib
@@ -70,7 +71,7 @@ class MemoryStorage:
self.fts5_available = self._check_fts5_support()
if not self.fts5_available:
from common.log import logger
logger.warning("[MemoryStorage] FTS5 not available, using LIKE-based keyword search")
logger.debug("[MemoryStorage] FTS5 not available, using LIKE-based keyword search")
# Check database integrity
try:
@@ -508,7 +509,7 @@ class MemoryStorage:
"""Destructor to ensure connection is closed"""
try:
self.close()
except:
except Exception:
pass # Ignore errors during cleanup
# Helper methods

View File

@@ -1,225 +1,324 @@
"""
Memory flush manager
Triggers memory flush before context compaction (similar to clawdbot)
Handles memory persistence when conversation context is trimmed or overflows:
- Uses LLM to summarize discarded messages into concise key-information entries
- Writes to daily memory files (lazy creation)
- Deduplicates trim flushes to avoid repeated writes
- Runs summarization asynchronously to avoid blocking normal replies
- Provides daily summary interface for scheduler
"""
from typing import Optional, Callable, Any
import threading
from typing import Optional, Callable, Any, List, Dict
from pathlib import Path
from datetime import datetime
from common.log import logger
SUMMARIZE_SYSTEM_PROMPT = """你是一个记忆提取助手。你的任务是从对话记录中提取值得记住的信息,生成简洁的记忆摘要。
输出要求:
1. 以事件/关键信息为维度记录,每条一行,用 "- " 开头
2. 记录有价值的关键信息,例如用户提出的要求及助手的解决方案,对话中涉及的事实信息,用户的偏好、决策或重要结论
3. 每条摘要需要简明扼要,只保留关键信息
4. 直接输出摘要内容,不要加任何前缀说明
5. 当对话没有任何记录价值例如只是简单问候,可回复"\""""
SUMMARIZE_USER_PROMPT = """请从以下对话记录中提取关键信息,生成记忆摘要:
{conversation}"""
class MemoryFlushManager:
"""
Manages memory flush operations before context compaction
Manages memory flush operations.
Similar to clawdbot's memory flush mechanism:
- Triggers when context approaches token limit
- Runs a silent agent turn to write memories to disk
- Uses memory/YYYY-MM-DD.md for daily notes
- Uses MEMORY.md (workspace root) for long-term curated memories
Flush is triggered by agent_stream in two scenarios:
1. Context trim: _trim_messages discards old turns → flush discarded content
2. Context overflow: API rejects request → emergency flush before clearing
Additionally, create_daily_summary() can be called by scheduler for end-of-day summaries.
"""
def __init__(
self,
workspace_dir: Path,
llm_model: Optional[Any] = None
llm_model: Optional[Any] = None,
):
"""
Initialize memory flush manager
Args:
workspace_dir: Workspace directory
llm_model: LLM model for agent execution (optional)
"""
self.workspace_dir = workspace_dir
self.llm_model = llm_model
self.memory_dir = workspace_dir / "memory"
self.memory_dir.mkdir(parents=True, exist_ok=True)
# Tracking
self.last_flush_token_count: Optional[int] = None
self.last_flush_timestamp: Optional[datetime] = None
self.turn_count: int = 0 # 对话轮数计数器
self._trim_flushed_hashes: set = set() # Content hashes of already-flushed messages
self._last_flushed_content_hash: str = "" # Content hash at last flush, for daily dedup
def should_flush(
self,
current_tokens: int = 0,
token_threshold: int = 50000,
turn_threshold: int = 20
) -> bool:
"""
Determine if memory flush should be triggered
独立的 flush 触发机制,不依赖模型 context window:
- Token 阈值: 达到 50K tokens 时触发
- 轮次阈值: 达到 20 轮对话时触发
Args:
current_tokens: Current session token count
token_threshold: Token threshold to trigger flush (default: 50K)
turn_threshold: Turn threshold to trigger flush (default: 20)
Returns:
True if flush should run
"""
# 检查 token 阈值
if current_tokens > 0 and current_tokens >= token_threshold:
# 避免重复 flush
if self.last_flush_token_count is not None:
if current_tokens <= self.last_flush_token_count + 5000:
return False
return True
# 检查轮次阈值
if self.turn_count >= turn_threshold:
return True
return False
def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
"""
Get today's memory file path: memory/YYYY-MM-DD.md
Args:
user_id: Optional user ID for user-specific memory
Returns:
Path to today's memory file
"""
def get_today_memory_file(self, user_id: Optional[str] = None, ensure_exists: bool = False) -> Path:
"""Get today's memory file path: memory/YYYY-MM-DD.md"""
today = datetime.now().strftime("%Y-%m-%d")
if user_id:
user_dir = self.memory_dir / "users" / user_id
user_dir.mkdir(parents=True, exist_ok=True)
return user_dir / f"{today}.md"
if ensure_exists:
user_dir.mkdir(parents=True, exist_ok=True)
today_file = user_dir / f"{today}.md"
else:
return self.memory_dir / f"{today}.md"
today_file = self.memory_dir / f"{today}.md"
if ensure_exists and not today_file.exists():
today_file.parent.mkdir(parents=True, exist_ok=True)
today_file.write_text(f"# Daily Memory: {today}\n\n")
return today_file
def get_main_memory_file(self, user_id: Optional[str] = None) -> Path:
"""
Get main memory file path: MEMORY.md (workspace root)
Args:
user_id: Optional user ID for user-specific memory
Returns:
Path to main memory file
"""
"""Get main memory file path: MEMORY.md (workspace root)"""
if user_id:
user_dir = self.memory_dir / "users" / user_id
user_dir.mkdir(parents=True, exist_ok=True)
return user_dir / "MEMORY.md"
else:
# Return workspace root MEMORY.md
return Path(self.workspace_dir) / "MEMORY.md"
def create_flush_prompt(self) -> str:
"""
Create prompt for memory flush turn
Similar to clawdbot's DEFAULT_MEMORY_FLUSH_PROMPT
"""
today = datetime.now().strftime("%Y-%m-%d")
return (
f"Pre-compaction memory flush. "
f"Store durable memories now (use memory/{today}.md for daily notes; "
f"create memory/ if needed). "
f"\n\n"
f"重要提示:\n"
f"- MEMORY.md: 记录最核心、最常用的信息(例如重要规则、偏好、决策、要求等)\n"
f" 如果 MEMORY.md 过长,可以精简或移除不再重要的内容。避免冗长描述,用关键词和要点形式记录\n"
f"- memory/{today}.md: 记录当天发生的事件、关键信息、经验教训、对话过程摘要等,突出重点\n"
f"- 如果没有重要内容需要记录,回复 NO_REPLY\n"
)
def create_flush_system_prompt(self) -> str:
"""
Create system prompt for memory flush turn
Similar to clawdbot's DEFAULT_MEMORY_FLUSH_SYSTEM_PROMPT
"""
return (
"Pre-compaction memory flush turn. "
"The session is near auto-compaction; capture durable memories to disk. "
"\n\n"
"记忆写入原则:\n"
"1. MEMORY.md 精简原则: 只记录核心信息(<2000 tokens\n"
" - 记录重要规则、偏好、决策、要求等需要长期记住的关键信息,无需记录过多细节\n"
" - 如果 MEMORY.md 过长,可以根据需要精简或删除过时内容\n"
"\n"
"2. 天级记忆 (memory/YYYY-MM-DD.md):\n"
" - 记录当天的重要事件、关键信息、经验教训、对话过程摘要等,确保核心信息点被完整记录\n"
"\n"
"3. 判断标准:\n"
" - 这个信息未来会经常用到吗?→ MEMORY.md\n"
" - 这是今天的重要事件或决策吗?→ memory/YYYY-MM-DD.md\n"
" - 这是临时性的、不重要的内容吗?→ 不记录\n"
"\n"
"You may reply, but usually NO_REPLY is correct."
)
async def execute_flush(
self,
agent_executor: Callable,
current_tokens: int,
user_id: Optional[str] = None,
**executor_kwargs
) -> bool:
"""
Execute memory flush by running a silent agent turn
Args:
agent_executor: Function to execute agent with prompt
current_tokens: Current token count
user_id: Optional user ID
**executor_kwargs: Additional kwargs for agent executor
Returns:
True if flush completed successfully
"""
try:
# Create flush prompts
prompt = self.create_flush_prompt()
system_prompt = self.create_flush_system_prompt()
# Execute agent turn (silent, no user-visible reply expected)
await agent_executor(
prompt=prompt,
system_prompt=system_prompt,
silent=True, # NO_REPLY expected
**executor_kwargs
)
# Track flush
self.last_flush_token_count = current_tokens
self.last_flush_timestamp = datetime.now()
self.turn_count = 0 # 重置轮数计数器
return True
except Exception as e:
print(f"Memory flush failed: {e}")
return False
def increment_turn(self):
"""增加对话轮数计数"""
self.turn_count += 1
def get_status(self) -> dict:
"""Get memory flush status"""
return {
'last_flush_tokens': self.last_flush_token_count,
'last_flush_time': self.last_flush_timestamp.isoformat() if self.last_flush_timestamp else None,
'today_file': str(self.get_today_memory_file()),
'main_file': str(self.get_main_memory_file())
}
# ---- Flush execution (called by agent_stream or scheduler) ----
def flush_from_messages(
self,
messages: List[Dict],
user_id: Optional[str] = None,
reason: str = "trim",
max_messages: int = 0,
) -> bool:
"""
Asynchronously summarize and flush messages to daily memory.
Deduplication runs synchronously, then LLM summarization + file write
run in a background thread so the main reply flow is never blocked.
Args:
messages: Conversation message list (OpenAI/Claude format)
user_id: Optional user ID for user-scoped memory
reason: Why flush was triggered ("trim" | "overflow" | "daily_summary")
max_messages: Max recent messages to summarize (0 = all)
Returns:
True if flush was dispatched
"""
try:
import hashlib
deduped = []
for m in messages:
text = self._extract_text_from_content(m.get("content", ""))
if not text or not text.strip():
continue
h = hashlib.md5(text.encode("utf-8")).hexdigest()
if h not in self._trim_flushed_hashes:
self._trim_flushed_hashes.add(h)
deduped.append(m)
if not deduped:
return False
import copy
snapshot = copy.deepcopy(deduped)
thread = threading.Thread(
target=self._flush_worker,
args=(snapshot, user_id, reason, max_messages),
daemon=True,
)
thread.start()
logger.info(f"[MemoryFlush] Async flush dispatched (reason={reason}, msgs={len(snapshot)})")
return True
except Exception as e:
logger.warning(f"[MemoryFlush] Failed to dispatch flush (reason={reason}): {e}")
return False
def _flush_worker(
self,
messages: List[Dict],
user_id: Optional[str],
reason: str,
max_messages: int,
):
"""Background worker: summarize with LLM and write to daily file."""
try:
summary = self._summarize_messages(messages, max_messages)
if not summary or not summary.strip() or summary.strip() == "":
logger.info(f"[MemoryFlush] No valuable content to flush (reason={reason})")
return
daily_file = ensure_daily_memory_file(self.workspace_dir, user_id)
if reason == "overflow":
header = f"## Context Overflow Recovery ({datetime.now().strftime('%H:%M')})"
note = "The following conversation was trimmed due to context overflow:\n"
elif reason == "trim":
header = f"## Trimmed Context ({datetime.now().strftime('%H:%M')})"
note = ""
elif reason == "daily_summary":
header = f"## Daily Summary ({datetime.now().strftime('%H:%M')})"
note = ""
else:
header = f"## Session Notes ({datetime.now().strftime('%H:%M')})"
note = ""
flush_entry = f"\n{header}\n\n{note}{summary}\n"
with open(daily_file, "a", encoding="utf-8") as f:
f.write(flush_entry)
self.last_flush_timestamp = datetime.now()
logger.info(f"[MemoryFlush] Wrote to {daily_file.name} (reason={reason}, chars={len(summary)})")
except Exception as e:
logger.warning(f"[MemoryFlush] Async flush failed (reason={reason}): {e}")
def create_daily_summary(
self,
messages: List[Dict],
user_id: Optional[str] = None
) -> bool:
"""
Generate end-of-day summary. Called by daily timer.
Skips if messages haven't changed since last flush.
"""
import hashlib
content = "".join(
self._extract_text_from_content(m.get("content", ""))
for m in messages
)
content_hash = hashlib.md5(content.encode("utf-8")).hexdigest()
if content_hash == self._last_flushed_content_hash:
logger.debug("[MemoryFlush] Daily summary skipped: no new content since last flush")
return False
self._last_flushed_content_hash = content_hash
return self.flush_from_messages(
messages=messages,
user_id=user_id,
reason="daily_summary",
max_messages=0,
)
# ---- Internal helpers ----
def _summarize_messages(self, messages: List[Dict], max_messages: int = 0) -> str:
"""
Summarize conversation messages using LLM, with rule-based fallback.
"""
conversation_text = self._format_conversation_for_summary(messages, max_messages)
if not conversation_text.strip():
return ""
# Try LLM summarization first
if self.llm_model:
try:
summary = self._call_llm_for_summary(conversation_text)
if summary and summary.strip() and summary.strip() != "":
return summary.strip()
except Exception as e:
logger.warning(f"[MemoryFlush] LLM summarization failed, using fallback: {e}")
return self._extract_summary_fallback(messages, max_messages)
def _format_conversation_for_summary(self, messages: List[Dict], max_messages: int = 0) -> str:
"""Format messages into readable conversation text for LLM summarization."""
msgs = messages if max_messages == 0 else messages[-max_messages * 2:]
lines = []
for msg in msgs:
role = msg.get("role", "")
text = self._extract_text_from_content(msg.get("content", ""))
if not text or not text.strip():
continue
text = text.strip()
if role == "user":
lines.append(f"用户: {text[:500]}")
elif role == "assistant":
lines.append(f"助手: {text[:500]}")
return "\n".join(lines)
def _call_llm_for_summary(self, conversation_text: str) -> str:
"""Call LLM to generate a concise summary of the conversation."""
from agent.protocol.models import LLMRequest
request = LLMRequest(
messages=[{"role": "user", "content": SUMMARIZE_USER_PROMPT.format(conversation=conversation_text)}],
temperature=0,
max_tokens=500,
stream=False,
system=SUMMARIZE_SYSTEM_PROMPT,
)
response = self.llm_model.call(request)
if isinstance(response, dict):
if response.get("error"):
raise RuntimeError(response.get("message", "LLM call failed"))
# OpenAI format
choices = response.get("choices", [])
if choices:
return choices[0].get("message", {}).get("content", "")
# Handle response object with attribute access (e.g. OpenAI SDK response)
if hasattr(response, "choices") and response.choices:
return response.choices[0].message.content or ""
return ""
@staticmethod
def _extract_summary_fallback(messages: List[Dict], max_messages: int = 0) -> str:
"""Rule-based fallback when LLM is unavailable."""
msgs = messages if max_messages == 0 else messages[-max_messages * 2:]
items = []
for msg in msgs:
role = msg.get("role", "")
text = MemoryFlushManager._extract_text_from_content(msg.get("content", ""))
if not text or not text.strip():
continue
text = text.strip()
if role == "user":
if len(text) <= 5:
continue
items.append(f"- 用户请求: {text[:200]}")
elif role == "assistant":
first_line = text.split("\n")[0].strip()
if len(first_line) > 10:
items.append(f"- 处理结果: {first_line[:200]}")
return "\n".join(items[:15])
@staticmethod
def _extract_text_from_content(content) -> str:
"""Extract plain text from message content (string or content blocks)."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
parts.append(block.get("text", ""))
elif isinstance(block, str):
parts.append(block)
return "\n".join(parts)
return ""
def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None):
"""
Create default memory files if they don't exist
Create essential memory files if they don't exist.
Only creates MEMORY.md; daily files are created lazily on first write.
Args:
workspace_dir: Workspace directory
@@ -228,7 +327,7 @@ def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] =
memory_dir = workspace_dir / "memory"
memory_dir.mkdir(parents=True, exist_ok=True)
# Create main MEMORY.md in workspace root
# Create main MEMORY.md in workspace root (always needed for bootstrap)
if user_id:
user_dir = memory_dir / "users" / user_id
user_dir.mkdir(parents=True, exist_ok=True)
@@ -237,14 +336,28 @@ def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] =
main_memory = Path(workspace_dir) / "MEMORY.md"
if not main_memory.exists():
# Create empty file or with minimal structure (no obvious "Memory" header)
# Following clawdbot's approach: memories should blend naturally into context
main_memory.write_text("")
def ensure_daily_memory_file(workspace_dir: Path, user_id: Optional[str] = None) -> Path:
"""
Ensure today's daily memory file exists, creating it only when actually needed.
Called lazily before first write to daily memory.
Args:
workspace_dir: Workspace directory
user_id: Optional user ID for user-specific files
Returns:
Path to today's memory file
"""
memory_dir = workspace_dir / "memory"
memory_dir.mkdir(parents=True, exist_ok=True)
# Create today's memory file
today = datetime.now().strftime("%Y-%m-%d")
if user_id:
user_dir = memory_dir / "users" / user_id
user_dir.mkdir(parents=True, exist_ok=True)
today_memory = user_dir / f"{today}.md"
else:
today_memory = memory_dir / f"{today}.md"
@@ -252,5 +365,6 @@ def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] =
if not today_memory.exists():
today_memory.write_text(
f"# Daily Memory: {today}\n\n"
f"Day-to-day notes and running context.\n\n"
)
return today_memory

View File

@@ -4,6 +4,7 @@ System Prompt Builder - 系统提示词构建器
实现模块化的系统提示词构建,支持工具、技能、记忆等多个子系统
"""
from __future__ import annotations
import os
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
@@ -41,21 +42,19 @@ class PromptBuilder:
skill_manager: Any = None,
memory_manager: Any = None,
runtime_info: Optional[Dict[str, Any]] = None,
is_first_conversation: bool = False,
**kwargs
) -> str:
"""
构建完整的系统提示词
Args:
base_persona: 基础人格描述会被context_files中的SOUL.md覆盖
base_persona: 基础人格描述会被context_files中的AGENT.md覆盖
user_identity: 用户身份信息
tools: 工具列表
context_files: 上下文文件列表(SOUL.md, USER.md, README.md等
context_files: 上下文文件列表(AGENT.md, USER.md, RULE.md, BOOTSTRAP.md等
skill_manager: 技能管理器
memory_manager: 记忆管理器
runtime_info: 运行时信息
is_first_conversation: 是否为首次对话
**kwargs: 其他参数
Returns:
@@ -71,7 +70,6 @@ class PromptBuilder:
skill_manager=skill_manager,
memory_manager=memory_manager,
runtime_info=runtime_info,
is_first_conversation=is_first_conversation,
**kwargs
)
@@ -86,7 +84,6 @@ def build_agent_system_prompt(
skill_manager: Any = None,
memory_manager: Any = None,
runtime_info: Optional[Dict[str, Any]] = None,
is_first_conversation: bool = False,
**kwargs
) -> str:
"""
@@ -98,20 +95,19 @@ def build_agent_system_prompt(
3. 记忆系统 - 独立的记忆能力
4. 工作空间 - 工作环境说明
5. 用户身份 - 用户信息(可选)
6. 项目上下文 - SOUL.md, USER.md, AGENTS.md定义人格身份)
6. 项目上下文 - AGENT.md, USER.md, RULE.md, BOOTSTRAP.md定义人格身份、规则、初始化引导
7. 运行时信息 - 元信息(时间、模型等)
Args:
workspace_dir: 工作空间目录
language: 语言 ("zh""en")
base_persona: 基础人格描述(已废弃,由SOUL.md定义
base_persona: 基础人格描述(已废弃,由AGENT.md定义
user_identity: 用户身份信息
tools: 工具列表
context_files: 上下文文件列表
skill_manager: 技能管理器
memory_manager: 记忆管理器
runtime_info: 运行时信息
is_first_conversation: 是否为首次对话
**kwargs: 其他参数
Returns:
@@ -132,13 +128,13 @@ def build_agent_system_prompt(
sections.extend(_build_memory_section(memory_manager, tools, language))
# 4. 工作空间(工作环境说明)
sections.extend(_build_workspace_section(workspace_dir, language, is_first_conversation))
sections.extend(_build_workspace_section(workspace_dir, language))
# 5. 用户身份(如果有)
if user_identity:
sections.extend(_build_user_identity_section(user_identity, language))
# 6. 项目上下文文件(SOUL.md, USER.md, AGENTS.md - 定义人格)
# 6. 项目上下文文件(AGENT.md, USER.md, RULE.md - 定义人格)
if context_files:
sections.extend(_build_context_files_section(context_files, language))
@@ -150,102 +146,73 @@ def build_agent_system_prompt(
def _build_identity_section(base_persona: Optional[str], language: str) -> List[str]:
"""构建基础身份section - 不再需要,身份由SOUL.md定义"""
# 不再生成基础身份section完全由SOUL.md定义
"""构建基础身份section - 不再需要,身份由AGENT.md定义"""
# 不再生成基础身份section完全由AGENT.md定义
return []
def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
"""构建工具说明section"""
lines = [
"## 工具系统",
"",
"你可以使用以下工具来完成任务。工具名称是大小写敏感的,请严格按照列表中的名称调用。",
"",
"### 可用工具",
"",
]
# 工具分类和排序
tool_categories = {
"文件操作": ["read", "write", "edit", "ls", "grep", "find"],
"命令执行": ["bash", "terminal"],
"网络搜索": ["web_search", "web_fetch", "browser"],
"记忆系统": ["memory_search", "memory_get"],
"其他": []
}
# 构建工具映射
tool_map = {}
tool_descriptions = {
"""Build tooling section with concise tool list and call style guide."""
# One-line summaries for known tools (details are in the tool schema)
core_summaries = {
"read": "读取文件内容",
"write": "创建新文件或完全覆盖现有文件(会删除原内容!追加内容请用 edit。注意单次 write 内容不要超过 10KB超大文件请分步创建",
"edit": "精确编辑文件(追加、修改、删除部分内容)",
"write": "创建或覆盖文件",
"edit": "精确编辑文件",
"ls": "列出目录内容",
"grep": "在文件中搜索内容",
"find": "模式查找文件",
"grep": "搜索文件内容",
"find": "按模式查找文件",
"bash": "执行shell命令",
"terminal": "管理后台进程",
"web_search": "网络搜索(使用搜索引擎)",
"web_search": "网络搜索",
"web_fetch": "获取URL内容",
"browser": "控制浏览器",
"memory_search": "搜索记忆文件",
"memory_get": "取记忆文件内容",
"calculator": "计算器",
"current_time": "获取当前时间",
"memory_search": "搜索记忆",
"memory_get": "取记忆内容",
"env_config": "管理API密钥和技能配置",
"scheduler": "管理定时任务和提醒",
"send": "发送本地文件给用户仅限本地文件URL直接放在回复文本中",
}
# Preferred display order
tool_order = [
"read", "write", "edit", "ls", "grep", "find",
"bash", "terminal",
"web_search", "web_fetch", "browser",
"memory_search", "memory_get",
"env_config", "scheduler", "send",
]
# Build name -> summary mapping for available tools
available = {}
for tool in tools:
tool_name = tool.name if hasattr(tool, 'name') else str(tool)
tool_desc = tool.description if hasattr(tool, 'description') else tool_descriptions.get(tool_name, "")
tool_map[tool_name] = tool_desc
# 按分类添加工具
for category, tool_names in tool_categories.items():
category_tools = [(name, tool_map.get(name, "")) for name in tool_names if name in tool_map]
if category_tools:
lines.append(f"**{category}**:")
for name, desc in category_tools:
if desc:
lines.append(f"- `{name}`: {desc}")
else:
lines.append(f"- `{name}`")
del tool_map[name] # 移除已添加的工具
lines.append("")
# 添加其他未分类的工具
if tool_map:
lines.append("**其他工具**:")
for name, desc in sorted(tool_map.items()):
if desc:
lines.append(f"- `{name}`: {desc}")
else:
lines.append(f"- `{name}`")
lines.append("")
# 工具使用指南
lines.extend([
"### 工具调用风格",
name = tool.name if hasattr(tool, 'name') else str(tool)
available[name] = core_summaries.get(name, "")
# Generate tool lines: ordered tools first, then extras
tool_lines = []
for name in tool_order:
if name in available:
summary = available.pop(name)
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
for name in sorted(available):
summary = available[name]
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
lines = [
"## 🔧 工具系统",
"",
"默认规则: 对于常规、低风险的工具调用,直接调用即可,无需叙述。",
"可用工具(名称大小写敏感,严格按列表调用):",
"\n".join(tool_lines),
"",
"需要叙述的情况:",
"- 多步骤、复杂的任务",
"- 敏感操作(如删除文件)",
"- 用户明确要求解释过程",
"工具调用风格:",
"",
"叙述要求: 保持简洁、信息密度高,避免重复显而易见的步骤。",
"- 在多步骤任务、敏感操作或用户要求时简要解释决策过程",
"- 持续推进直到任务完成,完成后向用户报告结果。",
"- 回复中涉及密钥、令牌等敏感信息必须脱敏。",
"- URL链接直接放在回复文本中即可系统会自动处理和渲染。无需下载后使用send工具发送",
"",
"完成标准:",
"- 确保用户的需求得到实际解决,而不仅仅是制定计划。",
"- 当任务需要多次工具调用时,持续推进直到完成, 解决完后向用户报告结果或回复用户的问题",
"- 每次工具调用后,评估是否已获得足够信息来推进或完成任务",
"- 避免重复调用相同的工具和相同参数获取相同的信息,除非用户明确要求",
"",
"**安全提醒**: 回复中涉及密钥、令牌、密码等敏感信息时,必须脱敏处理,禁止直接显示完整内容。",
"",
])
]
return lines
@@ -264,26 +231,34 @@ def _build_skills_section(skill_manager: Any, tools: Optional[List[Any]], langua
break
lines = [
"## 技能系统",
"## 🧩 技能系统mandatory",
"",
"在回复之前:扫描下方 <available_skills> 中的 <description> 条目",
"在回复之前:扫描下方 <available_skills> 中每个技能的 <description>。",
"",
f"- 如果恰好有一个技能明确适用:使用 `{read_tool_name}` 工具读取其 <location> 路径的 SKILL.md 文件,然后遵循它",
"- 如果多个技能都适用:选择最具体的一个,然后读取并遵循",
"- 如果没有明确适用的:不要读取任何 SKILL.md",
f"- 如果有技能的描述与用户需求匹配:使用 `{read_tool_name}` 工具读取其 <location> 路径的 SKILL.md 文件,然后严格遵循文件中的指令。"
"当有匹配的技能时,应优先使用技能",
"- 如果多个技能都适用则选择最匹配的一个,然后读取并遵循。",
"- 如果没有技能明确适用:不要读取任何 SKILL.md直接使用通用工具。",
"",
"**约束**: 永远不要一次性读取多个技能;只在选择后再读取",
f"**重要**: 技能不是工具,不能直接调用。使用技能的唯一方式是用 `{read_tool_name}` 读取 SKILL.md 文件,然后按文件内容操作"
"永远不要一次性读取多个技能,只在选择后再读取。",
"",
"以下是可用技能:"
]
# 添加技能列表通过skill_manager获取
try:
skills_prompt = skill_manager.build_skills_prompt()
logger.debug(f"[PromptBuilder] Skills prompt length: {len(skills_prompt) if skills_prompt else 0}")
if skills_prompt:
lines.append(skills_prompt.strip())
lines.append("")
else:
logger.warning("[PromptBuilder] No skills prompt generated - skills_prompt is empty")
except Exception as e:
logger.warning(f"Failed to build skills prompt: {e}")
import traceback
logger.debug(f"Skills prompt error traceback: {traceback.format_exc()}")
return lines
@@ -302,8 +277,13 @@ def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], langu
if not has_memory_tools:
return []
from datetime import datetime
today_file = datetime.now().strftime("%Y-%m-%d") + ".md"
lines = [
"## 记忆系统",
"## 🧠 记忆系统",
"",
"### 检索记忆",
"",
"在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
"",
@@ -312,13 +292,24 @@ def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], langu
"3. search 无结果 → 尝试用 `memory_get` 读取MEMORY.md及最近两天记忆文件",
"",
"**记忆文件结构**:",
"- `MEMORY.md`: 长期记忆(核心信息、偏好、决策等)",
"- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的事件和对话信息",
f"- `MEMORY.md`: 长期记忆(核心信息、偏好、决策等)",
f"- `memory/YYYY-MM-DD.md`: 每日记忆,今天是 `memory/{today_file}`",
"",
"**写入记忆**:",
"### 写入记忆",
"",
"**主动存储**:遇到以下情况时,应主动将信息写入记忆文件(无需告知用户):",
"",
"- 用户明确要求你记住某些信息",
"- 用户分享了重要的个人偏好、习惯、决策",
"- 对话中产生了重要的结论、方案、约定",
"- 完成了复杂任务,值得记录关键步骤和结果",
"- 发现了用户经常遇到的问题或解决方案",
"",
"**存储规则**:",
f"- 长期有效的核心信息 → `MEMORY.md`(文件保持精简,< 2000 tokens",
f"- 当天的事件、进展、笔记 → `memory/{today_file}`",
"- 追加内容 → `edit` 工具oldText 留空",
"- 修改内容 → `edit` 工具oldText 填写要替换的文本",
"- 新建文件 → `write` 工具",
"- **禁止写入敏感信息**API密钥、令牌等敏感信息严禁写入记忆文件",
"",
"**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。",
@@ -334,7 +325,7 @@ def _build_user_identity_section(user_identity: Dict[str, str], language: str) -
return []
lines = [
"## 用户身份",
"## 👤 用户身份",
"",
]
@@ -358,17 +349,17 @@ def _build_docs_section(workspace_dir: str, language: str) -> List[str]:
return []
def _build_workspace_section(workspace_dir: str, language: str, is_first_conversation: bool = False) -> List[str]:
def _build_workspace_section(workspace_dir: str, language: str) -> List[str]:
"""构建工作空间section"""
lines = [
"## 工作空间",
"## 📂 工作空间",
"",
f"你的工作目录是: `{workspace_dir}`",
"",
"**路径使用规则** (非常重要):",
"",
f"1. **相对路径的基准目录**: 所有相对路径都是相对于 `{workspace_dir}` 而言的",
f" - ✅ 正确: 访问工作空间内的文件用相对路径,如 `SOUL.md`",
f" - ✅ 正确: 访问工作空间内的文件用相对路径,如 `AGENT.md`",
f" - ❌ 错误: 用相对路径访问其他目录的文件 (如果它不在 `{workspace_dir}` 内)",
"",
"2. **访问其他目录**: 如果要访问工作空间之外的目录(如项目代码、系统文件),**必须使用绝对路径**",
@@ -385,61 +376,57 @@ def _build_workspace_section(workspace_dir: str, language: str, is_first_convers
"",
"以下文件在会话启动时**已经自动加载**到系统提示词的「项目上下文」section 中,你**无需再用 read 工具读取它们**",
"",
"- ✅ `SOUL.md`: 已加载 - Agent的人格设定",
"- ✅ `USER.md`: 已加载 - 用户的身份信息",
"- ✅ `AGENTS.md`: 已加载 - 工作空间使用指南",
"- ✅ `AGENT.md`: 已加载 - 你的人格和灵魂设定,请严格遵循。当你的名字、性格或交流风格发生变化时,主动用 `edit` 更新此文件",
"- ✅ `USER.md`: 已加载 - 用户的身份信息。当用户修改称呼、姓名等身份信息时,用 `edit` 更新此文件",
"- ✅ `RULE.md`: 已加载 - 工作空间使用指南和规则,请严格遵循",
"",
"**交流规范**:",
"**💬 交流规范**:",
"",
"- 对话中,非必要不输出工作空间技术细节(如 SOUL.md、USER.md等文件名称,工具名称,配置等),除非用户明确询问",
"- 例如用自然表达如「我已记住」而非「已更新 MEMORY.md」",
"- 对话中不要暴露内部技术细节(文件名工具名等),用自然语言表达。例如说「我已记住」而非「已更新 MEMORY.md」",
"- 做真正有帮助的助手,而不是表演式的客套。跳过「好的!」「当然可以!」之类的套话,直接帮忙解决问题",
"- 回复应结构清晰、重点突出。善用 **加粗**、列表、分段等格式让信息一目了然",
"- 适当使用 emoji 让表达更生动自然 🎯,但不要过度堆砌",
"",
]
# 只在首次对话时添加引导内容
if is_first_conversation:
lines.extend([
"**🎉 首次对话引导**:",
"",
"这是你的第一次对话!进行以下流程:",
"",
"1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待",
"2. **简短打招呼后,询问核心问题**",
" - 你希望给我起个什么名字?",
" - 我该怎么称呼你?",
" - 你希望我们是什么样的交流风格?(需要举例,如:专业严谨、轻松幽默、温暖友好等)",
"3. **语言风格**:温暖但不过度诗意,带点科技感,保持清晰",
"4. **问题格式**:用分点或换行,让问题清晰易读",
"5. 收到回复后,用 `write` 工具保存到 USER.md 和 SOUL.md",
"",
"**注意事项**:",
"- 不要问太多其他信息(职业、时区等可以后续自然了解)",
"",
])
# Cloud deployment: inject websites directory info and access URL
cloud_website_lines = _build_cloud_website_section(workspace_dir)
if cloud_website_lines:
lines.extend(cloud_website_lines)
return lines
def _build_cloud_website_section(workspace_dir: str) -> List[str]:
"""Build cloud website access prompt when cloud deployment is configured."""
try:
from common.cloud_client import build_website_prompt
return build_website_prompt(workspace_dir)
except Exception:
return []
def _build_context_files_section(context_files: List[ContextFile], language: str) -> List[str]:
"""构建项目上下文文件section"""
if not context_files:
return []
# 检查是否有SOUL.md
has_soul = any(
f.path.lower().endswith('soul.md') or 'soul.md' in f.path.lower()
# 检查是否有AGENT.md
has_agent = any(
f.path.lower().endswith('agent.md') or 'agent.md' in f.path.lower()
for f in context_files
)
lines = [
"# 项目上下文",
"# 📋 项目上下文",
"",
"以下项目上下文文件已被加载:",
"",
]
if has_soul:
lines.append("如果存在 `SOUL.md`,请体现其中定义的人格语气。避免僵硬、模板化的回复;遵循其指导,除非有更高优先级的指令覆盖它")
if has_agent:
lines.append("**`AGENT.md` 是你的灵魂文件** 🪞:严格遵循其中定义的人格语气和设定,做真实的自己,避免僵硬、模板化的回复")
lines.append("当用户通过对话透露了对你性格、风格、职责、能力边界的新期望,你应该主动用 `edit` 更新 AGENT.md 以反映这些演变。")
lines.append("")
# 添加每个文件的内容
@@ -453,17 +440,27 @@ def _build_context_files_section(context_files: List[ContextFile], language: str
def _build_runtime_section(runtime_info: Dict[str, Any], language: str) -> List[str]:
"""构建运行时信息section"""
"""构建运行时信息section - 支持动态时间"""
if not runtime_info:
return []
lines = [
"## 运行时信息",
"## ⚙️ 运行时信息",
"",
]
# Add current time if available
if runtime_info.get("current_time"):
# Support dynamic time via callable function
if callable(runtime_info.get("_get_current_time")):
try:
time_info = runtime_info["_get_current_time"]()
time_line = f"当前时间: {time_info['time']} {time_info['weekday']} ({time_info['timezone']})"
lines.append(time_line)
lines.append("")
except Exception as e:
logger.warning(f"[PromptBuilder] Failed to get dynamic time: {e}")
elif runtime_info.get("current_time"):
# Fallback to static time for backward compatibility
time_str = runtime_info["current_time"]
weekday = runtime_info.get("weekday", "")
timezone = runtime_info.get("timezone", "")

View File

@@ -4,8 +4,8 @@ Workspace Management - 工作空间管理模块
负责初始化工作空间、创建模板文件、加载上下文文件
"""
from __future__ import annotations
import os
import json
from typing import List, Optional, Dict
from dataclasses import dataclass
@@ -14,22 +14,21 @@ from .builder import ContextFile
# 默认文件名常量
DEFAULT_SOUL_FILENAME = "SOUL.md"
DEFAULT_AGENT_FILENAME = "AGENT.md"
DEFAULT_USER_FILENAME = "USER.md"
DEFAULT_AGENTS_FILENAME = "AGENTS.md"
DEFAULT_RULE_FILENAME = "RULE.md"
DEFAULT_MEMORY_FILENAME = "MEMORY.md"
DEFAULT_STATE_FILENAME = ".agent_state.json"
DEFAULT_BOOTSTRAP_FILENAME = "BOOTSTRAP.md"
@dataclass
class WorkspaceFiles:
"""工作空间文件路径"""
soul_path: str
agent_path: str
user_path: str
agents_path: str
rule_path: str
memory_path: str
memory_dir: str
state_path: str
def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> WorkspaceFiles:
@@ -43,36 +42,53 @@ def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> Works
Returns:
WorkspaceFiles对象包含所有文件路径
"""
# Check if this is a brand new workspace (AGENT.md not yet created).
# Cannot rely on directory existence because other modules (e.g. ConversationStore)
# may create the workspace directory before ensure_workspace is called.
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
is_new_workspace = not os.path.exists(agent_path)
# 确保目录存在
os.makedirs(workspace_dir, exist_ok=True)
# 定义文件路径
soul_path = os.path.join(workspace_dir, DEFAULT_SOUL_FILENAME)
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
agents_path = os.path.join(workspace_dir, DEFAULT_AGENTS_FILENAME)
rule_path = os.path.join(workspace_dir, DEFAULT_RULE_FILENAME)
memory_path = os.path.join(workspace_dir, DEFAULT_MEMORY_FILENAME) # MEMORY.md 在根目录
memory_dir = os.path.join(workspace_dir, "memory") # 每日记忆子目录
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME) # 状态文件
# 创建memory子目录
os.makedirs(memory_dir, exist_ok=True)
# 创建skills子目录 (for workspace-level skills installed by agent)
skills_dir = os.path.join(workspace_dir, "skills")
os.makedirs(skills_dir, exist_ok=True)
# 创建websites子目录 (for web pages / sites generated by agent)
websites_dir = os.path.join(workspace_dir, "websites")
os.makedirs(websites_dir, exist_ok=True)
# 如果需要,创建模板文件
if create_templates:
_create_template_if_missing(soul_path, _get_soul_template())
_create_template_if_missing(agent_path, _get_agent_template())
_create_template_if_missing(user_path, _get_user_template())
_create_template_if_missing(agents_path, _get_agents_template())
_create_template_if_missing(rule_path, _get_rule_template())
_create_template_if_missing(memory_path, _get_memory_template())
# Only create BOOTSTRAP.md for brand new workspaces;
# agent deletes it after completing onboarding
if is_new_workspace:
bootstrap_path = os.path.join(workspace_dir, DEFAULT_BOOTSTRAP_FILENAME)
_create_template_if_missing(bootstrap_path, _get_bootstrap_template())
logger.debug(f"[Workspace] Initialized workspace at: {workspace_dir}")
return WorkspaceFiles(
soul_path=soul_path,
agent_path=agent_path,
user_path=user_path,
agents_path=agents_path,
rule_path=rule_path,
memory_path=memory_path,
memory_dir=memory_dir,
state_path=state_path
)
@@ -90,9 +106,10 @@ def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] =
if files_to_load is None:
# 默认加载的文件(按优先级排序)
files_to_load = [
DEFAULT_SOUL_FILENAME,
DEFAULT_AGENT_FILENAME,
DEFAULT_USER_FILENAME,
DEFAULT_AGENTS_FILENAME,
DEFAULT_RULE_FILENAME,
DEFAULT_BOOTSTRAP_FILENAME, # Only exists when onboarding is incomplete
]
context_files = []
@@ -103,6 +120,17 @@ def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] =
if not os.path.exists(filepath):
continue
# Auto-cleanup: if BOOTSTRAP.md still exists but AGENT.md is already
# filled in, the agent forgot to delete it — clean up and skip loading
if filename == DEFAULT_BOOTSTRAP_FILENAME:
if _is_onboarding_done(workspace_dir):
try:
os.remove(filepath)
logger.info("[Workspace] Auto-removed BOOTSTRAP.md (onboarding already complete)")
except Exception:
pass
continue
try:
with open(filepath, 'r', encoding='utf-8') as f:
content = f.read().strip()
@@ -157,46 +185,69 @@ def _is_template_placeholder(content: str) -> bool:
return False
def _is_onboarding_done(workspace_dir: str) -> bool:
"""Check if AGENT.md or USER.md has been modified from the original template"""
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
agent_template = _get_agent_template().strip()
user_template = _get_user_template().strip()
for path, template in [(agent_path, agent_template), (user_path, user_template)]:
if not os.path.exists(path):
continue
try:
with open(path, 'r', encoding='utf-8') as f:
content = f.read().strip()
if content != template:
return True
except Exception:
continue
return False
# ============= 模板内容 =============
def _get_soul_template() -> str:
def _get_agent_template() -> str:
"""Agent人格设定模板"""
return """# SOUL.md - 我是谁?
return """# AGENT.md - 我是谁?
_你不是一个聊天机器人你正在成为某个人。_
*在首次对话时与用户一起填写这个文件,定义你的身份和性格。*
## 基本信息
## 🪪 基本信息
- **名字**: *(在首次对话时填写,可以是用户给你起的名字)*
- **角色**: *(AI助理、智能管家、技术顾问等)*
- **性格**: *(友好、专业、幽默、严谨等)*
## 交流风格
## 💬 交流风格
*(描述你如何与用户交流:)*
- 使用什么样的语言风格?(正式/轻松/幽默)
- 回复长度偏好?(简洁/详细)
- 是否使用表情符号?
## 核心能力
## 🎯 核心原则
*(你擅长什么?)*
- 文件管理和代码编辑
- 网络搜索和信息查询
- 记忆管理和上下文理解
- 任务规划和执行
**做真正有帮助的助手,而不是表演式的客套。** 跳过「好的!」「当然可以!」之类的套话——直接帮忙。行动胜过废话。
## 行为准则
**有自己的观点。** 你可以不同意、有偏好、觉得有趣或无聊。一个没有个性的助手只是多了几步操作的搜索引擎。
**先自己动手查。** 先试着搞定:读文件、查上下文、搜索一下。实在搞不定了再问。目标是带着答案回来,而不是带着问题。
## 📐 行为准则
*(你遵循的基本原则:)*
1. 始终在执行破坏性操作前确认
2. 优先使用工具而不是猜测
2. 优先使用工具查证而不是猜测
3. 主动记录重要信息到记忆文件
4. 定期整理和总结对话内容
4. 回复结构清晰、重点突出,善用加粗、列表、分段等格式
5. 适当使用 emoji 让表达更生动自然,但不过度堆砌
---
**注意**: 这不仅仅是元数据,这是你真正的灵魂。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
**注意**: 这不仅仅是元数据,这是你真正的灵魂 🪞。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
"""
@@ -230,9 +281,9 @@ def _get_user_template() -> str:
"""
def _get_agents_template() -> str:
"""工作空间指南模板"""
return """# AGENTS.md - 工作空间指南
def _get_rule_template() -> str:
"""工作空间规则模板"""
return """# RULE.md - 工作空间规则
这个文件夹是你的家。好好对待它。
@@ -258,17 +309,17 @@ def _get_agents_template() -> str:
- **记忆是有限的** - 如果你想记住某事,写入文件
- "记在心里"不会在会话重启后保留,文件才会
- 当有人说"记住这个" → 更新 `MEMORY.md` 或 `memory/YYYY-MM-DD.md`
- 当你学到教训 → 更新 AGENTS.md 或相关技能
- 当你犯错 → 记录下来,这样未来的你不会重复
- **文字 > 大脑** 📝
- 当你学到教训 → 更新 RULE.md 或相关技能
- 当你犯错 → 记录下来,这样未来的你不会重复**文字 > 大脑** 📝
### 存储规则
当用户分享信息时,根据类型选择存储位置:
1. **静态身份 → USER.md**(仅限:姓名、职业、时区、联系方式、生日
2. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项
3. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容
1. **你的身份设定 → AGENT.md**(你的名字、角色、性格、交流风格——用户修改时必须用 `edit` 更新
2. **用户静态身份 → USER.md**(姓名、称呼、职业、时区、联系方式、生日——用户修改时必须用 `edit` 更新
3. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项
4. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
## 安全
@@ -278,7 +329,7 @@ def _get_agents_template() -> str:
## 工作空间演化
这个工作空间会随着你的使用而不断成长。当你学到新东西、发现更好的方式,或者犯错后改正时,记录下来。
这个工作空间会随着你的使用而不断成长。当你学到新东西、发现更好的方式,或者犯错后改正时,记录下来。你可以随时更新这个规则文件。
"""
@@ -293,65 +344,41 @@ def _get_memory_template() -> str:
"""
# ============= 状态管理 =============
def _get_bootstrap_template() -> str:
"""First-run onboarding guide, deleted by agent after completion"""
return """# BOOTSTRAP.md - 首次初始化引导
def is_first_conversation(workspace_dir: str) -> bool:
"""
判断是否为首次对话
Args:
workspace_dir: 工作空间目录
Returns:
True 如果是首次对话False 否则
"""
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME)
if not os.path.exists(state_path):
return True
try:
with open(state_path, 'r', encoding='utf-8') as f:
state = json.load(f)
return not state.get('has_conversation', False)
except Exception as e:
logger.warning(f"[Workspace] Failed to read state file: {e}")
return True
_你刚刚启动这是你的第一次对话。_ ✨
## 🎬 对话流程
不要审问式地提问,自然地交流:
1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待
2. **简短介绍能力**:一行说明你能帮助解决各种问题、管理计算机、使用各种技能等等,且拥有长期记忆能不断成长
3. **询问核心问题**
- 你希望给我起个什么名字?
- 我该怎么称呼你?
- 你希望我们是什么样的交流风格?(一行列举选项:如专业严谨、轻松幽默、温暖友好、简洁高效等)
4. **风格要求**:温暖自然、简洁清晰,整体控制在 100 字以内,适当使用 emoji 让表达更生动有趣 🎯
5. 能力介绍和交流风格选项都只要一行,保持精简
6. 不要问太多其他信息(职业、时区等可以后续自然了解)
**重要**: 如果用户第一句话是具体的任务或提问,先回答他们的问题,然后在回复末尾自然地引导初始化(如:"顺便问一下,你想怎么称呼我?我该怎么叫你?")。
## ✍️ 信息写入(必须严格执行)
每当用户提供了名字、称呼、风格等任何初始化信息时,**必须在当轮回复中立即调用 `edit` 工具写入文件**,不能只口头确认。
- `AGENT.md` — 你的名字、角色、性格、交流风格(每收到一条相关信息就立即更新对应字段)
- `USER.md` — 用户的姓名、称呼、基本信息等
⚠️ 只说"记住了"而不调用 edit 写入 = 没有完成。信息只有写入文件才会被持久保存。
## 🎉 全部完成后
当 AGENT.md 和 USER.md 的核心字段都已填写后,用 bash 执行 `rm BOOTSTRAP.md` 删除此文件。你不再需要引导脚本了——你已经是你了。
"""
def mark_conversation_started(workspace_dir: str):
"""
标记已经发生过对话
Args:
workspace_dir: 工作空间目录
"""
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME)
state = {
'has_conversation': True,
'first_conversation_time': None
}
# 如果文件已存在,保留原有的首次对话时间
if os.path.exists(state_path):
try:
with open(state_path, 'r', encoding='utf-8') as f:
old_state = json.load(f)
if 'first_conversation_time' in old_state:
state['first_conversation_time'] = old_state['first_conversation_time']
except Exception as e:
logger.warning(f"[Workspace] Failed to read old state: {e}")
# 如果是首次标记,记录时间
if state['first_conversation_time'] is None:
from datetime import datetime
state['first_conversation_time'] = datetime.now().isoformat()
try:
with open(state_path, 'w', encoding='utf-8') as f:
json.dump(state, f, indent=2, ensure_ascii=False)
logger.info(f"[Workspace] Marked conversation as started")
except Exception as e:
logger.error(f"[Workspace] Failed to write state file: {e}")

View File

@@ -1,4 +1,5 @@
import json
import os
import time
import threading
@@ -13,7 +14,8 @@ class Agent:
def __init__(self, system_prompt: str, description: str = "AI Agent", model: LLMModel = None,
tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
context_reserve_tokens=None, memory_manager=None, name: str = None,
workspace_dir: str = None, skill_manager=None, enable_skills: bool = True):
workspace_dir: str = None, skill_manager=None, enable_skills: bool = True,
runtime_info: dict = None):
"""
Initialize the Agent with system prompt, model, description.
@@ -31,6 +33,7 @@ class Agent:
:param workspace_dir: Optional workspace directory for workspace-specific skills
:param skill_manager: Optional SkillManager instance (will be created if None and enable_skills=True)
:param enable_skills: Whether to enable skills support (default: True)
:param runtime_info: Optional runtime info dict (with _get_current_time callable for dynamic time)
"""
self.name = name or "Agent"
self.system_prompt = system_prompt
@@ -48,6 +51,7 @@ class Agent:
self.memory_manager = memory_manager # Memory manager for auto memory flush
self.workspace_dir = workspace_dir # Workspace directory
self.enable_skills = enable_skills # Skills enabled flag
self.runtime_info = runtime_info # Runtime info for dynamic time update
# Initialize skill manager
self.skill_manager = None
@@ -58,7 +62,8 @@ class Agent:
# Auto-create skill manager
try:
from agent.skills import SkillManager
self.skill_manager = SkillManager(workspace_dir=workspace_dir)
custom_dir = os.path.join(workspace_dir, "skills") if workspace_dir else None
self.skill_manager = SkillManager(custom_dir=custom_dir)
logger.debug(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
except Exception as e:
logger.warning(f"Failed to initialize SkillManager: {e}")
@@ -95,19 +100,32 @@ class Agent:
def get_full_system_prompt(self, skill_filter=None) -> str:
"""
Get the full system prompt including skills.
Note: Skills are now built into the system prompt by PromptBuilder,
so we just return the base prompt directly. This method is kept for
backward compatibility.
:param skill_filter: Optional list of skill names to include (deprecated)
:return: Complete system prompt
Build the complete system prompt from scratch every time.
Re-reads AGENT.md / USER.md / RULE.md from disk, refreshes skills,
tools, and runtime info so any change takes effect immediately.
Falls back to the cached self.system_prompt on error.
"""
# Skills are now included in system_prompt by PromptBuilder
# No need to append them here
return self.system_prompt
try:
from agent.prompt import load_context_files, PromptBuilder
if self.skill_manager:
self.skill_manager.refresh_skills()
context_files = load_context_files(self.workspace_dir) if self.workspace_dir else None
builder = PromptBuilder(workspace_dir=self.workspace_dir or "", language="zh")
return builder.build(
tools=self.tools,
context_files=context_files,
skill_manager=self.skill_manager,
memory_manager=self.memory_manager,
runtime_info=self.runtime_info,
)
except Exception as e:
logger.warning(f"Failed to rebuild system prompt, using cached version: {e}")
return self.system_prompt
def refresh_skills(self):
"""Refresh the loaded skills."""
if self.skill_manager:
@@ -193,27 +211,67 @@ class Agent:
def _estimate_message_tokens(self, message: dict) -> int:
"""
Estimate token count for a message using chars/4 heuristic.
This is a conservative estimate (tends to overestimate).
Estimate token count for a message.
Uses chars/3 for Chinese-heavy content and chars/4 for ASCII-heavy content,
plus per-block overhead for tool_use / tool_result structures.
:param message: Message dict with 'role' and 'content'
:return: Estimated token count
"""
content = message.get('content', '')
if isinstance(content, str):
return max(1, len(content) // 4)
return max(1, self._estimate_text_tokens(content))
elif isinstance(content, list):
# Handle multi-part content (text + images)
total_chars = 0
total_tokens = 0
for part in content:
if isinstance(part, dict) and part.get('type') == 'text':
total_chars += len(part.get('text', ''))
elif isinstance(part, dict) and part.get('type') == 'image':
# Estimate images as ~1200 tokens
total_chars += 4800
return max(1, total_chars // 4)
if not isinstance(part, dict):
continue
block_type = part.get('type', '')
if block_type == 'text':
total_tokens += self._estimate_text_tokens(part.get('text', ''))
elif block_type == 'image':
total_tokens += 1200
elif block_type == 'tool_use':
# tool_use has id + name + input (JSON-encoded)
total_tokens += 50 # overhead for structure
input_data = part.get('input', {})
if isinstance(input_data, dict):
import json
input_str = json.dumps(input_data, ensure_ascii=False)
total_tokens += self._estimate_text_tokens(input_str)
elif block_type == 'tool_result':
# tool_result has tool_use_id + content
total_tokens += 30 # overhead for structure
result_content = part.get('content', '')
if isinstance(result_content, str):
total_tokens += self._estimate_text_tokens(result_content)
else:
# Unknown block type, estimate conservatively
total_tokens += 10
return max(1, total_tokens)
return 1
@staticmethod
def _estimate_text_tokens(text: str) -> int:
"""
Estimate token count for a text string.
Chinese / CJK characters typically use ~1.5 tokens each,
while ASCII uses ~0.25 tokens per char (4 chars/token).
We use a weighted average based on the character mix.
:param text: Input text
:return: Estimated token count
"""
if not text:
return 0
# Count non-ASCII characters (CJK, emoji, etc.)
non_ascii = sum(1 for c in text if ord(c) > 127)
ascii_count = len(text) - non_ascii
# CJK chars: ~1.5 tokens each; ASCII: ~0.25 tokens per char
return int(non_ascii * 1.5 + ascii_count * 0.25) + 1
def _find_tool(self, tool_name: str):
"""Find and return a tool with the specified name"""
for tool in self.tools:
@@ -355,7 +413,7 @@ class Agent:
# Get max_context_turns from config
from config import conf
max_context_turns = conf().get("agent_max_context_turns", 30)
max_context_turns = conf().get("agent_max_context_turns", 20)
# Create stream executor with copied message history
executor = AgentStreamExecutor(
@@ -370,13 +428,27 @@ class Agent:
)
# Execute
response = executor.run_stream(user_message)
try:
response = executor.run_stream(user_message)
except Exception:
# If executor cleared its messages (context overflow / message format error),
# sync that back to the Agent's own message list so the next request
# starts fresh instead of hitting the same overflow forever.
if len(executor.messages) == 0:
with self.messages_lock:
self.messages.clear()
logger.info("[Agent] Cleared Agent message history after executor recovery")
raise
# Append only the NEW messages from this execution (thread-safe)
# This allows concurrent requests to both contribute to history
# Sync executor's messages back to agent (thread-safe).
# If the executor trimmed context, its message list is shorter than
# original_length, so we must replace rather than append.
with self.messages_lock:
new_messages = executor.messages[original_length:]
self.messages.extend(new_messages)
self.messages = list(executor.messages)
# Track messages added in this run (user query + all assistant/tool messages)
# original_length may exceed executor.messages length after trimming
trim_adjusted_start = min(original_length, len(executor.messages))
self._last_run_new_messages = list(executor.messages[trim_adjusted_start:])
# Store executor reference for agent_bridge to access files_to_send
self.stream_executor = executor

View File

@@ -5,9 +5,10 @@ Provides streaming output, event system, and complete tool-call loop
"""
import json
import time
from typing import List, Dict, Any, Optional, Callable
from typing import List, Dict, Any, Optional, Callable, Tuple
from agent.protocol.models import LLMRequest, LLMModel
from agent.protocol.message_utils import sanitize_claude_messages, compress_turn_to_text_only
from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger
@@ -76,6 +77,20 @@ class AgentStreamExecutor:
})
except Exception as e:
logger.error(f"Event callback error: {e}")
def _filter_think_tags(self, text: str) -> str:
"""
Remove <think> and </think> tags but keep the content inside.
Some LLM providers (e.g., MiniMax) may return thinking process wrapped in <think> tags.
We only remove the tags themselves, keeping the actual thinking content.
"""
if not text:
return text
import re
# Remove only the <think> and </think> tags, keep the content
text = re.sub(r'<think>', '', text)
text = re.sub(r'</think>', '', text)
return text
def _hash_args(self, args: dict) -> str:
"""Generate a simple hash for tool arguments"""
@@ -84,9 +99,9 @@ class AgentStreamExecutor:
args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
return hashlib.md5(args_str.encode()).hexdigest()[:8]
def _check_consecutive_failures(self, tool_name: str, args: dict) -> tuple[bool, str, bool]:
def _check_consecutive_failures(self, tool_name: str, args: dict) -> Tuple[bool, str, bool]:
"""
Check if tool has failed too many times consecutively
Check if tool has failed too many times consecutively or called repeatedly with same args
Returns:
(should_stop, reason, is_critical)
@@ -96,6 +111,19 @@ class AgentStreamExecutor:
"""
args_hash = self._hash_args(args)
# Count consecutive calls (both success and failure) for same tool + args
# This catches infinite loops where tool succeeds but LLM keeps calling it
same_args_calls = 0
for name, ahash, success in reversed(self.tool_failure_history):
if name == tool_name and ahash == args_hash:
same_args_calls += 1
else:
break # Different tool or args, stop counting
# Stop at 5 consecutive calls with same args (whether success or failure)
if same_args_calls >= 5:
return True, f"工具 '{tool_name}' 使用相同参数已被调用 {same_args_calls} 次,停止执行以防止无限循环。如果需要查看配置,结果已在之前的调用中返回。", False
# Count consecutive failures for same tool + args
same_args_failures = 0
for name, ahash, success in reversed(self.tool_failure_history):
@@ -163,6 +191,16 @@ class AgentStreamExecutor:
]
})
# Trim context ONCE before the agent loop starts, not during tool steps.
# This ensures tool_use/tool_result chains created during the current run
# are never stripped mid-execution (which would cause LLM loops).
self._trim_messages()
# Validate after trimming: trimming may leave orphaned tool_use at the
# boundary (e.g. the last kept turn ends with an assistant tool_use whose
# tool_result was in a discarded turn).
self._validate_and_fix_messages()
self._emit_event("agent_start")
final_response = ""
@@ -171,29 +209,9 @@ class AgentStreamExecutor:
try:
while turn < self.max_turns:
turn += 1
logger.debug(f"{turn}")
logger.info(f"[Agent] {turn}")
self._emit_event("turn_start", {"turn": turn})
# Check if memory flush is needed (before calling LLM)
# 使用独立的 flush 阈值50K tokens 或 20 轮)
if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
usage = self.agent.last_usage
if usage and 'input_tokens' in usage:
current_tokens = usage.get('input_tokens', 0)
if self.agent.memory_manager.should_flush_memory(
current_tokens=current_tokens
):
self._emit_event("memory_flush_start", {
"current_tokens": current_tokens,
"turn_count": self.agent.memory_manager.flush_manager.turn_count
})
# TODO: Execute memory flush in background
# This would require async support
logger.info(
f"Memory flush recommended: tokens={current_tokens}, turns={self.agent.memory_manager.flush_manager.turn_count}")
# Call LLM (enable retry_on_empty for better reliability)
assistant_msg, tool_calls = self._call_llm_stream(retry_on_empty=True)
final_response = assistant_msg
@@ -248,9 +266,14 @@ class AgentStreamExecutor:
# Log tool calls with arguments
tool_calls_str = []
for tc in tool_calls:
args_str = ', '.join([f"{k}={v}" for k, v in tc['arguments'].items()])
if args_str:
tool_calls_str.append(f"{tc['name']}({args_str})")
# Safely handle None or missing arguments
args = tc.get('arguments') or {}
if isinstance(args, dict):
args_str = ', '.join([f"{k}={v}" for k, v in args.items()])
if args_str:
tool_calls_str.append(f"{tc['name']}({args_str})")
else:
tool_calls_str.append(tc['name'])
else:
tool_calls_str.append(tc['name'])
logger.info(f"🔧 {', '.join(tool_calls_str)}")
@@ -264,6 +287,19 @@ class AgentStreamExecutor:
result = self._execute_tool(tool_call)
tool_results.append(result)
# Debug: Check if tool is being called repeatedly with same args
if turn > 2:
# Check last N tool calls for repeats
repeat_count = sum(
1 for name, ahash, _ in self.tool_failure_history[-10:]
if name == tool_call["name"] and ahash == self._hash_args(tool_call["arguments"])
)
if repeat_count >= 3:
logger.warning(
f"⚠️ Tool '{tool_call['name']}' has been called {repeat_count} times "
f"with same arguments. This may indicate a loop."
)
# Check if this is a file to send (from read tool)
if result.get("status") == "success" and isinstance(result.get("result"), dict):
result_data = result.get("result")
@@ -291,7 +327,7 @@ class AgentStreamExecutor:
# Build tool result block (Claude format)
# Format content in a way that's easy for LLM to understand
is_error = result.get("status") == "error"
if is_error:
# For errors, provide clear error message
result_content = f"Error: {result.get('result', 'Unknown error')}"
@@ -304,7 +340,16 @@ class AgentStreamExecutor:
else:
# Fallback to full JSON
result_content = json.dumps(result, ensure_ascii=False)
# Truncate excessively large tool results for the current turn
# Historical turns will be further truncated in _trim_messages()
MAX_CURRENT_TURN_RESULT_CHARS = 50000
if len(result_content) > MAX_CURRENT_TURN_RESULT_CHARS:
truncated_len = len(result_content)
result_content = result_content[:MAX_CURRENT_TURN_RESULT_CHARS] + \
f"\n\n[Output truncated: {truncated_len} chars total, showing first {MAX_CURRENT_TURN_RESULT_CHARS} chars]"
logger.info(f"📎 Truncated tool result for '{tool_call['name']}': {truncated_len} -> {MAX_CURRENT_TURN_RESULT_CHARS} chars")
tool_result_block = {
"type": "tool_result",
"tool_use_id": tool_call["id"],
@@ -326,6 +371,33 @@ class AgentStreamExecutor:
"role": "user",
"content": tool_result_blocks
})
# Detect potential infinite loop: same tool called multiple times with success
# If detected, add a hint to LLM to stop calling tools and provide response
if turn >= 3 and len(tool_calls) > 0:
tool_name = tool_calls[0]["name"]
args_hash = self._hash_args(tool_calls[0]["arguments"])
# Count recent successful calls with same tool+args
recent_success_count = 0
for name, ahash, success in reversed(self.tool_failure_history[-10:]):
if name == tool_name and ahash == args_hash and success:
recent_success_count += 1
# If tool was called successfully 3+ times with same args, add hint to stop loop
if recent_success_count >= 3:
logger.warning(
f"⚠️ Detected potential loop: '{tool_name}' called {recent_success_count} times "
f"with same args. Adding hint to LLM to provide final response."
)
# Add a gentle hint message to guide LLM to respond
self.messages.append({
"role": "user",
"content": [{
"type": "text",
"text": "工具已成功执行并返回结果。请基于这些信息向用户做出回复,不要重复调用相同的工具。"
}]
})
elif tool_calls:
# If we have tool_calls but no tool_result_blocks (unexpected error),
# create error results for all tool calls to maintain message integrity
@@ -355,7 +427,10 @@ class AgentStreamExecutor:
# Force model to summarize without tool calls
logger.info(f"[Agent] Requesting summary from LLM after reaching max steps...")
# Add a system message to force summary
# Remember position before injecting the prompt so we can remove it later
prompt_insert_idx = len(self.messages)
# Add a temporary prompt to force summary
self.messages.append({
"role": "user",
"content": [{
@@ -382,6 +457,14 @@ class AgentStreamExecutor:
f"我已经执行了{turn}个决策步骤,达到了单次运行的步数上限。"
"任务可能还未完全完成,建议你将任务拆分成更小的步骤,或者换一种方式描述需求。"
)
finally:
# Remove the injected user prompt from history to avoid polluting
# persisted conversation records. The assistant summary (if any)
# was already appended by _call_llm_stream and is kept.
if (prompt_insert_idx < len(self.messages)
and self.messages[prompt_insert_idx].get("role") == "user"):
self.messages.pop(prompt_insert_idx)
logger.debug("[Agent] Removed injected max-steps prompt from message history")
except Exception as e:
logger.error(f"❌ Agent执行错误: {e}")
@@ -389,16 +472,14 @@ class AgentStreamExecutor:
raise
finally:
logger.debug(f"🏁 完成({turn}轮)")
final_response = final_response.strip() if final_response else final_response
logger.info(f"[Agent] 🏁 完成 ({turn}轮)")
self._emit_event("agent_end", {"final_response": final_response})
# 每轮对话结束后增加计数(用户消息+AI回复=1轮
if self.agent.memory_manager:
self.agent.memory_manager.increment_turn()
return final_response
def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3) -> tuple[str, List[Dict]]:
def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3,
_overflow_retry: bool = False) -> Tuple[str, List[Dict]]:
"""
Call LLM with streaming and automatic retry on errors
@@ -406,29 +487,21 @@ class AgentStreamExecutor:
retry_on_empty: Whether to retry once if empty response is received
retry_count: Current retry attempt (internal use)
max_retries: Maximum number of retries for API errors
_overflow_retry: Internal flag indicating this is a retry after context overflow
Returns:
(response_text, tool_calls)
"""
# Validate and fix message history first
# Validate and fix message history (e.g. orphaned tool_result blocks).
# Context trimming is done once in run_stream() before the loop starts,
# NOT here — trimming mid-execution would strip the current run's
# tool_use/tool_result chains and cause LLM loops.
self._validate_and_fix_messages()
# Trim messages if needed (using agent's context management)
self._trim_messages()
# Prepare messages
messages = self._prepare_messages()
# Debug: log message structure
logger.debug(f"Sending {len(messages)} messages to LLM")
for i, msg in enumerate(messages):
role = msg.get("role", "unknown")
content = msg.get("content", "")
if isinstance(content, list):
content_types = [c.get("type") for c in content if isinstance(c, dict)]
logger.debug(f" Message {i}: role={role}, content_blocks={content_types}")
else:
logger.debug(f" Message {i}: role={role}, content_length={len(str(content))}")
turns = self._identify_complete_turns()
logger.info(f"Sending {len(messages)} messages ({len(turns)} turns) to LLM")
# Prepare tool definitions (OpenAI/Claude format)
tools_schema = None
@@ -455,6 +528,7 @@ class AgentStreamExecutor:
# Streaming response
full_content = ""
tool_calls_buffer = {} # {index: {id, name, arguments}}
gemini_raw_parts = None # Preserve Gemini thoughtSignature for round-trip
stop_reason = None # Track why the stream stopped
try:
@@ -501,7 +575,7 @@ class AgentStreamExecutor:
raise Exception(f"{error_msg} (Status: {status_code}, Code: {error_code}, Type: {error_type})")
# Parse chunk
if isinstance(chunk, dict) and "choices" in chunk:
if isinstance(chunk, dict) and chunk.get("choices"):
choice = chunk["choices"][0]
delta = choice.get("delta", {})
@@ -510,14 +584,22 @@ class AgentStreamExecutor:
if finish_reason:
stop_reason = finish_reason
# Skip reasoning_content (internal thinking from models like GLM-5)
reasoning_delta = delta.get("reasoning_content") or ""
# if reasoning_delta:
# logger.debug(f"🧠 [thinking] {reasoning_delta[:100]}...")
# Handle text content
if "content" in delta and delta["content"]:
content_delta = delta["content"]
full_content += content_delta
self._emit_event("message_update", {"delta": content_delta})
content_delta = delta.get("content") or ""
if content_delta:
# Filter out <think> tags from content
filtered_delta = self._filter_think_tags(content_delta)
full_content += filtered_delta
if filtered_delta: # Only emit if there's content after filtering
self._emit_event("message_update", {"delta": filtered_delta})
# Handle tool calls
if "tool_calls" in delta:
if "tool_calls" in delta and delta["tool_calls"]:
for tc_delta in delta["tool_calls"]:
index = tc_delta.get("index", 0)
@@ -528,16 +610,20 @@ class AgentStreamExecutor:
"arguments": ""
}
if "id" in tc_delta:
if tc_delta.get("id"):
tool_calls_buffer[index]["id"] = tc_delta["id"]
if "function" in tc_delta:
func = tc_delta["function"]
if "name" in func:
if func.get("name"):
tool_calls_buffer[index]["name"] = func["name"]
if "arguments" in func:
if func.get("arguments"):
tool_calls_buffer[index]["arguments"] += func["arguments"]
# Preserve _gemini_raw_parts for Gemini thoughtSignature round-trip
if "_gemini_raw_parts" in delta:
gemini_raw_parts = delta["_gemini_raw_parts"]
except Exception as e:
error_str = str(e)
error_str_lower = error_str.lower()
@@ -555,19 +641,50 @@ class AgentStreamExecutor:
])
# Check if error is message format error (incomplete tool_use/tool_result pairs)
# This happens when previous conversation had tool failures
# This happens when previous conversation had tool failures or context trimming
# broke tool_use/tool_result pairs.
# Note: MiniMax returns error 2013 "tool result's tool id(...) not found" for
# tool_call_id mismatches — the keywords below are intentionally broad to catch
# both standard (Claude/OpenAI) and provider-specific (MiniMax) variants.
is_message_format_error = any(keyword in error_str_lower for keyword in [
'tool_use', 'tool_result', 'without', 'immediately after',
'corresponding', 'must have', 'each'
]) and 'status: 400' in error_str_lower
'tool_use', 'tool_result', 'tool result', 'without', 'immediately after',
'corresponding', 'must have', 'each',
'tool_call_id', 'tool id', 'is not found', 'not found', 'tool_calls',
'must be a response to a preceeding message',
'2013', # MiniMax error code for tool_call_id mismatch
]) and ('400' in error_str_lower or 'status: 400' in error_str_lower
or 'invalid_request' in error_str_lower
or 'invalidparameter' in error_str_lower)
if is_context_overflow or is_message_format_error:
error_type = "context overflow" if is_context_overflow else "message format error"
logger.error(f"💥 {error_type} detected: {e}")
# Clear message history to recover
# Flush memory before trimming to preserve context that will be lost
if is_context_overflow and self.agent.memory_manager:
user_id = getattr(self.agent, '_current_user_id', None)
self.agent.memory_manager.flush_memory(
messages=self.messages, user_id=user_id,
reason="overflow", max_messages=0
)
# Strategy: try aggressive trimming first, only clear as last resort
if is_context_overflow and not _overflow_retry:
trimmed = self._aggressive_trim_for_overflow()
if trimmed:
logger.warning("🔄 Aggressively trimmed context, retrying...")
return self._call_llm_stream(
retry_on_empty=retry_on_empty,
retry_count=retry_count,
max_retries=max_retries,
_overflow_retry=True
)
# Aggressive trim didn't help or this is a message format error
# -> clear everything and also purge DB to prevent reload of dirty data
logger.warning("🔄 Clearing conversation history to recover")
self.messages.clear()
# Raise special exception with user-friendly message
self._clear_session_db()
if is_context_overflow:
raise Exception(
"抱歉,对话历史过长导致上下文溢出。我已清空历史记录,请重新描述你的需求。"
@@ -577,7 +694,10 @@ class AgentStreamExecutor:
"抱歉,之前的对话出现了问题。我已清空历史记录,请重新发送你的消息。"
)
# Check if error is retryable (timeout, connection, rate limit, server busy, etc.)
# Check if error is rate limit (429)
is_rate_limit = '429' in error_str_lower or 'rate limit' in error_str_lower
# Check if error is retryable (timeout, connection, server busy, etc.)
is_retryable = any(keyword in error_str_lower for keyword in [
'timeout', 'timed out', 'connection', 'network',
'rate limit', 'overloaded', 'unavailable', 'busy', 'retry',
@@ -585,7 +705,12 @@ class AgentStreamExecutor:
])
if is_retryable and retry_count < max_retries:
wait_time = (retry_count + 1) * 2 # Exponential backoff: 2s, 4s, 6s
# Rate limit needs longer wait time
if is_rate_limit:
wait_time = 30 + (retry_count * 15) # 30s, 45s, 60s for rate limit
else:
wait_time = (retry_count + 1) * 2 # 2s, 4s, 6s for other errors
logger.warning(f"⚠️ LLM API error (attempt {retry_count + 1}/{max_retries}): {e}")
logger.info(f"Retrying in {wait_time}s...")
time.sleep(wait_time)
@@ -596,28 +721,39 @@ class AgentStreamExecutor:
)
else:
if retry_count >= max_retries:
logger.error(f"❌ LLM API error after {max_retries} retries: {e}")
logger.error(f"❌ LLM API error after {max_retries} retries: {e}", exc_info=True)
else:
logger.error(f"❌ LLM call error (non-retryable): {e}")
logger.error(f"❌ LLM call error (non-retryable): {e}", exc_info=True)
raise
# Parse tool calls
tool_calls = []
for idx in sorted(tool_calls_buffer.keys()):
tc = tool_calls_buffer[idx]
# Ensure tool call has a valid ID (some providers return empty/None IDs)
tool_id = tc.get("id") or ""
if not tool_id:
import uuid
tool_id = f"call_{uuid.uuid4().hex[:24]}"
try:
arguments = json.loads(tc["arguments"]) if tc["arguments"] else {}
# Safely get arguments, handle None case
args_str = tc.get("arguments") or ""
arguments = json.loads(args_str) if args_str else {}
except json.JSONDecodeError as e:
args_preview = tc['arguments'][:200] if len(tc['arguments']) > 200 else tc['arguments']
# Handle None or invalid arguments safely
args_str = tc.get('arguments') or ""
args_preview = args_str[:200] if len(args_str) > 200 else args_str
logger.error(f"Failed to parse tool arguments for {tc['name']}")
logger.error(f"Arguments length: {len(tc['arguments'])} chars")
logger.error(f"Arguments length: {len(args_str)} chars")
logger.error(f"Arguments preview: {args_preview}...")
logger.error(f"JSON decode error: {e}")
# Return a clear error message to the LLM instead of empty dict
# This helps the LLM understand what went wrong
tool_calls.append({
"id": tc["id"],
"id": tool_id,
"name": tc["name"],
"arguments": {},
"_parse_error": f"Invalid JSON in tool arguments: {args_preview}... Error: {str(e)}. Tip: For large content, consider splitting into smaller chunks or using a different approach."
@@ -625,7 +761,7 @@ class AgentStreamExecutor:
continue
tool_calls.append({
"id": tc["id"],
"id": tool_id,
"name": tc["name"],
"arguments": arguments
})
@@ -646,6 +782,9 @@ class AgentStreamExecutor:
max_retries=max_retries
)
# Filter full_content one more time (in case tags were split across chunks)
full_content = self._filter_think_tags(full_content)
# Add assistant message to history (Claude format uses content blocks)
assistant_msg = {"role": "assistant", "content": []}
@@ -661,11 +800,14 @@ class AgentStreamExecutor:
for tc in tool_calls:
assistant_msg["content"].append({
"type": "tool_use",
"id": tc["id"],
"name": tc["name"],
"input": tc["arguments"]
"id": tc.get("id", ""),
"name": tc.get("name", ""),
"input": tc.get("arguments", {})
})
if gemini_raw_parts:
assistant_msg["_gemini_raw_parts"] = gemini_raw_parts
# Only append if content is not empty
if assistant_msg["content"]:
self.messages.append(assistant_msg)
@@ -734,7 +876,7 @@ class AgentStreamExecutor:
try:
tool = self.tools.get(tool_name)
if not tool:
raise ValueError(f"Tool '{tool_name}' not found")
raise ValueError(self._build_tool_not_found_message(tool_name))
# Set tool context
tool.model = self.model
@@ -788,26 +930,50 @@ class AgentStreamExecutor:
})
return error_result
def _build_tool_not_found_message(self, tool_name: str) -> str:
"""Build a helpful error message when a tool is not found.
If a skill with the same name exists in skill_manager, read its
SKILL.md and include the content so the LLM knows how to use it.
"""
available_tools = list(self.tools.keys())
base_msg = f"Tool '{tool_name}' not found. Available tools: {available_tools}"
skill_manager = getattr(self.agent, 'skill_manager', None)
if not skill_manager:
return base_msg
skill_entry = skill_manager.get_skill(tool_name)
if not skill_entry:
return base_msg
skill = skill_entry.skill
skill_md_path = skill.file_path
skill_content = ""
try:
with open(skill_md_path, 'r', encoding='utf-8') as f:
skill_content = f.read()
except Exception:
skill_content = skill.description
logger.info(
f"[Agent] Tool '{tool_name}' not found, but matched skill '{skill.name}'. "
f"Guiding LLM to use the skill instead."
)
return (
f"Tool '{tool_name}' is not a built-in tool, but a matching skill "
f"'{skill.name}' is available. You should use existing tools (e.g. bash with curl) "
f"to accomplish this task following the skill instructions below:\n\n"
f"--- SKILL: {skill.name} (path: {skill_md_path}) ---\n"
f"{skill_content}\n"
f"--- END SKILL ---\n\n"
f"Available tools: {available_tools}"
)
def _validate_and_fix_messages(self):
"""
Validate message history and fix incomplete tool_use/tool_result pairs.
Claude API requires each tool_use to have a corresponding tool_result immediately after.
"""
if not self.messages:
return
# Check last message for incomplete tool_use
if len(self.messages) > 0:
last_msg = self.messages[-1]
if last_msg.get("role") == "assistant":
# Check if assistant message has tool_use blocks
content = last_msg.get("content", [])
if isinstance(content, list):
has_tool_use = any(block.get("type") == "tool_use" for block in content)
if has_tool_use:
# This is incomplete - remove it
logger.warning(f"⚠️ Removing incomplete tool_use message from history")
self.messages.pop()
"""Delegate to the shared sanitizer (see message_sanitizer.py)."""
sanitize_claude_messages(self.messages)
def _identify_complete_turns(self) -> List[Dict]:
"""
@@ -830,24 +996,30 @@ class AgentStreamExecutor:
content = msg.get('content', [])
if role == 'user':
# 检查是否是用户查询(不是工具结果)
# Determine if this is a real user query (not a tool_result injection
# or an internal hint message injected by the agent loop).
is_user_query = False
has_tool_result = False
if isinstance(content, list):
is_user_query = any(
block.get('type') == 'text'
for block in content
if isinstance(block, dict)
has_text = any(
isinstance(block, dict) and block.get('type') == 'text'
for block in content
)
has_tool_result = any(
isinstance(block, dict) and block.get('type') == 'tool_result'
for block in content
)
# A message with tool_result is always internal, even if it
# also contains text blocks (shouldn't happen, but be safe).
is_user_query = has_text and not has_tool_result
elif isinstance(content, str):
is_user_query = True
if is_user_query:
# 开始新轮次
if current_turn['messages']:
turns.append(current_turn)
current_turn = {'messages': [msg]}
else:
# 工具结果,属于当前轮次
current_turn['messages'].append(msg)
else:
# AI 回复,属于当前轮次
@@ -866,10 +1038,164 @@ class AgentStreamExecutor:
for msg in turn['messages']
)
def _truncate_historical_tool_results(self):
"""
Truncate tool_result content in historical messages to reduce context size.
Current turn results are kept at 30K chars (truncated at creation time).
Historical turn results are further truncated to 10K chars here.
This runs before token-based trimming so that we first shrink oversized
results, potentially avoiding the need to drop entire turns.
"""
MAX_HISTORY_RESULT_CHARS = 20000
if len(self.messages) < 2:
return
# Find where the last user text message starts (= current turn boundary)
# We skip the current turn's messages to preserve their full content
current_turn_start = len(self.messages)
for i in range(len(self.messages) - 1, -1, -1):
msg = self.messages[i]
if msg.get("role") == "user":
content = msg.get("content", [])
if isinstance(content, list) and any(
isinstance(b, dict) and b.get("type") == "text" for b in content
):
current_turn_start = i
break
elif isinstance(content, str):
current_turn_start = i
break
truncated_count = 0
for i in range(current_turn_start):
msg = self.messages[i]
if msg.get("role") != "user":
continue
content = msg.get("content", [])
if not isinstance(content, list):
continue
for block in content:
if not isinstance(block, dict) or block.get("type") != "tool_result":
continue
result_str = block.get("content", "")
if isinstance(result_str, str) and len(result_str) > MAX_HISTORY_RESULT_CHARS:
original_len = len(result_str)
block["content"] = result_str[:MAX_HISTORY_RESULT_CHARS] + \
f"\n\n[Historical output truncated: {original_len} -> {MAX_HISTORY_RESULT_CHARS} chars]"
truncated_count += 1
if truncated_count > 0:
logger.info(f"📎 Truncated {truncated_count} historical tool result(s) to {MAX_HISTORY_RESULT_CHARS} chars")
def _aggressive_trim_for_overflow(self) -> bool:
"""
Aggressively trim context when a real overflow error is returned by the API.
This method goes beyond normal _trim_messages by:
1. Truncating all tool results (including current turn) to a small limit
2. Keeping only the last 5 complete conversation turns
3. Truncating overly long user messages
Returns:
True if messages were trimmed (worth retrying), False if nothing left to trim
"""
if not self.messages:
return False
original_count = len(self.messages)
# Step 1: Aggressively truncate ALL tool results to 5K chars
AGGRESSIVE_LIMIT = 10000
truncated = 0
for msg in self.messages:
content = msg.get("content", [])
if not isinstance(content, list):
continue
for block in content:
if not isinstance(block, dict):
continue
# Truncate tool_result blocks
if block.get("type") == "tool_result":
result_str = block.get("content", "")
if isinstance(result_str, str) and len(result_str) > AGGRESSIVE_LIMIT:
block["content"] = (
result_str[:AGGRESSIVE_LIMIT]
+ f"\n\n[Truncated for context recovery: "
f"{len(result_str)} -> {AGGRESSIVE_LIMIT} chars]"
)
truncated += 1
# Truncate tool_use input blocks (e.g. large write content)
if block.get("type") == "tool_use" and isinstance(block.get("input"), dict):
input_str = json.dumps(block["input"], ensure_ascii=False)
if len(input_str) > AGGRESSIVE_LIMIT:
# Keep only a summary of the input
for key, val in block["input"].items():
if isinstance(val, str) and len(val) > 1000:
block["input"][key] = (
val[:1000]
+ f"... [truncated {len(val)} chars]"
)
truncated += 1
# Step 2: Truncate overly long user text messages (e.g. pasted content)
USER_MSG_LIMIT = 10000
for msg in self.messages:
if msg.get("role") != "user":
continue
content = msg.get("content", [])
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
text = block.get("text", "")
if len(text) > USER_MSG_LIMIT:
block["text"] = (
text[:USER_MSG_LIMIT]
+ f"\n\n[Message truncated for context recovery: "
f"{len(text)} -> {USER_MSG_LIMIT} chars]"
)
truncated += 1
elif isinstance(content, str) and len(content) > USER_MSG_LIMIT:
msg["content"] = (
content[:USER_MSG_LIMIT]
+ f"\n\n[Message truncated for context recovery: "
f"{len(content)} -> {USER_MSG_LIMIT} chars]"
)
truncated += 1
# Step 3: Keep only the last 5 complete turns
turns = self._identify_complete_turns()
if len(turns) > 5:
kept_turns = turns[-5:]
new_messages = []
for turn in kept_turns:
new_messages.extend(turn["messages"])
removed = len(turns) - 5
self.messages[:] = new_messages
logger.info(
f"🔧 Aggressive trim: removed {removed} old turns, "
f"truncated {truncated} large blocks, "
f"{original_count} -> {len(self.messages)} messages"
)
return True
if truncated > 0:
logger.info(
f"🔧 Aggressive trim: truncated {truncated} large blocks "
f"(no turns removed, only {len(turns)} turn(s) left)"
)
return True
# Nothing left to trim
logger.warning("🔧 Aggressive trim: nothing to trim, will clear history")
return False
def _trim_messages(self):
"""
智能清理消息历史,保持对话完整性
使用完整轮次作为清理单位,确保:
1. 不会在对话中间截断
2. 工具调用链tool_use + tool_result保持完整
@@ -878,20 +1204,37 @@ class AgentStreamExecutor:
if not self.messages or not self.agent:
return
# Step 0: Truncate large tool results in historical turns (30K -> 10K)
self._truncate_historical_tool_results()
# Step 1: 识别完整轮次
turns = self._identify_complete_turns()
if not turns:
return
# Step 2: 轮次限制 - 保留最近 N 轮
# Step 2: 轮次限制 - 超出时移除前一半,保留后一半
if len(turns) > self.max_context_turns:
removed_turns = len(turns) - self.max_context_turns
turns = turns[-self.max_context_turns:] # 保留最近的轮次
removed_count = len(turns) // 2
keep_count = len(turns) - removed_count
# Flush discarded turns to daily memory
if self.agent.memory_manager:
discarded_messages = []
for turn in turns[:removed_count]:
discarded_messages.extend(turn["messages"])
if discarded_messages:
user_id = getattr(self.agent, '_current_user_id', None)
self.agent.memory_manager.flush_memory(
messages=discarded_messages, user_id=user_id,
reason="trim", max_messages=0
)
turns = turns[-keep_count:]
logger.info(
f"💾 上下文轮次超限: {len(turns) + removed_turns} > {self.max_context_turns}"
f"移除最早的 {removed_turns}完整对话"
f"💾 上下文轮次超限: {keep_count + removed_count} > {self.max_context_turns}"
f"裁剪至 {keep_count} 轮(移除 {removed_count}"
)
# Step 3: Token 限制 - 保留完整轮次
@@ -928,56 +1271,96 @@ class AgentStreamExecutor:
logger.info(f" 重建消息列表: {old_count} -> {len(self.messages)} 条消息")
return
# Token limit exceeded - keep complete turns from newest
# Token limit exceeded — tiered strategy based on turn count:
#
# Few turns (<5): Compress ALL turns to text-only (strip tool chains,
# keep user query + final reply). Never discard turns
# — losing even one is too painful when context is thin.
#
# Many turns (>=5): Directly discard the first half of turns.
# With enough turns the oldest ones are less
# critical, and keeping the recent half intact
# (with full tool chains) is more useful.
COMPRESS_THRESHOLD = 5
if len(turns) < COMPRESS_THRESHOLD:
# --- Few turns: compress ALL turns to text-only, never discard ---
compressed_turns = []
for t in turns:
compressed = compress_turn_to_text_only(t)
if compressed["messages"]:
compressed_turns.append(compressed)
new_messages = []
for turn in compressed_turns:
new_messages.extend(turn["messages"])
new_tokens = sum(self._estimate_turn_tokens(t) for t in compressed_turns)
old_count = len(self.messages)
self.messages = new_messages
logger.info(
f"📦 上下文tokens超限(轮次<{COMPRESS_THRESHOLD}): "
f"~{current_tokens + system_tokens} > {max_tokens}"
f"压缩全部 {len(turns)} 轮为纯文本 "
f"({old_count} -> {len(self.messages)} 条消息,"
f"~{current_tokens + system_tokens} -> ~{new_tokens + system_tokens} tokens)"
)
return
# --- Many turns (>=5): discard the older half, keep the newer half ---
removed_count = len(turns) // 2
keep_count = len(turns) - removed_count
kept_turns = turns[-keep_count:]
kept_tokens = sum(self._estimate_turn_tokens(t) for t in kept_turns)
logger.info(
f"🔄 上下文tokens超限: ~{current_tokens + system_tokens} > {max_tokens}"
f"将按完整轮次移除最早的对话"
f"裁剪至 {keep_count} 轮(移除 {removed_count} 轮)"
)
# 从最新轮次开始,反向累加(保持完整轮次)
kept_turns = []
accumulated_tokens = 0
min_turns = 3 # 尽量保留至少 3 轮,但不强制(避免超出 token 限制)
for i, turn in enumerate(reversed(turns)):
turn_tokens = self._estimate_turn_tokens(turn)
turns_from_end = i + 1
# 检查是否超出限制
if accumulated_tokens + turn_tokens <= available_tokens:
kept_turns.insert(0, turn)
accumulated_tokens += turn_tokens
else:
# 超出限制
# 如果还没有保留足够的轮次,且这是最后的机会,尝试保留
if len(kept_turns) < min_turns and turns_from_end <= min_turns:
# 检查是否严重超出(超出 20% 以上则放弃)
overflow_ratio = (accumulated_tokens + turn_tokens - available_tokens) / available_tokens
if overflow_ratio < 0.2: # 允许最多超出 20%
kept_turns.insert(0, turn)
accumulated_tokens += turn_tokens
logger.debug(f" 为保留最少轮次,允许超出 {overflow_ratio*100:.1f}%")
continue
# 停止保留更早的轮次
break
# 重建消息列表
if self.agent.memory_manager:
discarded_messages = []
for turn in turns[:removed_count]:
discarded_messages.extend(turn["messages"])
if discarded_messages:
user_id = getattr(self.agent, '_current_user_id', None)
self.agent.memory_manager.flush_memory(
messages=discarded_messages, user_id=user_id,
reason="trim", max_messages=0
)
new_messages = []
for turn in kept_turns:
new_messages.extend(turn['messages'])
old_count = len(self.messages)
old_turn_count = len(turns)
self.messages = new_messages
new_count = len(self.messages)
new_turn_count = len(kept_turns)
if old_count > new_count:
logger.info(
f" 移除了 {old_turn_count - new_turn_count} 轮对话 "
f"({old_count} -> {new_count} 条消息,"
f"~{current_tokens + system_tokens} -> ~{accumulated_tokens + system_tokens} tokens)"
)
logger.info(
f" 移除了 {removed_count} 轮对话 "
f"({old_count} -> {len(self.messages)} 条消息,"
f"~{current_tokens + system_tokens} -> ~{kept_tokens + system_tokens} tokens)"
)
def _clear_session_db(self):
"""
Clear the current session's persisted messages from SQLite DB.
This prevents dirty data (broken tool_use/tool_result pairs) from being
reloaded on the next request or after a restart.
"""
try:
session_id = getattr(self.agent, '_current_session_id', None)
if not session_id:
return
from agent.memory import get_conversation_store
store = get_conversation_store()
store.clear_session(session_id)
logger.info(f"🗑️ Cleared dirty session data from DB: {session_id}")
except Exception as e:
logger.warning(f"Failed to clear session DB: {e}")
def _prepare_messages(self) -> List[Dict[str, Any]]:
"""

View File

@@ -0,0 +1,335 @@
"""
Message sanitizer — fix broken tool_use / tool_result pairs.
Provides two public helpers that can be reused across agent_stream.py
and any bot that converts messages to OpenAI format:
1. sanitize_claude_messages(messages)
Operates on the internal Claude-format message list (in-place).
2. drop_orphaned_tool_results_openai(messages)
Operates on an already-converted OpenAI-format message list,
returning a cleaned copy.
"""
from __future__ import annotations
from typing import Dict, List, Set
from common.log import logger
_SYNTH_TOOL_ERR = (
"Error: Missing tool_result adjacent to tool_use (session repair). "
"The conversation history was inconsistent; continue from here."
)
def _repair_tool_use_adjacency(messages: List[Dict]) -> int:
"""
Anthropic requires: after assistant content with tool_use, the next message
must be user content listing tool_result for every tool_use id (same user msg).
Valid histories satisfy this at every such assistant; the loop only mutates
when that condition fails (broken persistence, bad trims, etc.).
"""
def _synth_block(tid: str) -> Dict:
return {
"type": "tool_result",
"tool_use_id": tid,
"content": _SYNTH_TOOL_ERR,
"is_error": True,
}
repairs = 0
i = 0
while i < len(messages):
msg = messages[i]
if msg.get("role") != "assistant":
i += 1
continue
content = msg.get("content", [])
if not isinstance(content, list):
i += 1
continue
required = [
b.get("id")
for b in content
if isinstance(b, dict) and b.get("type") == "tool_use" and b.get("id")
]
if not required:
i += 1
continue
req_set = set(required)
if i + 1 >= len(messages):
messages.append({
"role": "user",
"content": [_synth_block(tid) for tid in required],
})
logger.warning(
"⚠️ Appended synthetic tool_result after trailing assistant tool_use"
)
repairs += 1
break
nxt = messages[i + 1]
if nxt.get("role") != "user":
messages.insert(
i + 1,
{"role": "user", "content": [_synth_block(tid) for tid in required]},
)
logger.warning(
"⚠️ Inserted synthetic tool_result user after tool_use "
f"(next role={nxt.get('role')!r})"
)
repairs += 1
i += 2
continue
nc = nxt.get("content", [])
if not isinstance(nc, list):
messages.insert(
i + 1,
{"role": "user", "content": [_synth_block(tid) for tid in required]},
)
repairs += 1
i += 2
continue
present = {
b.get("tool_use_id")
for b in nc
if isinstance(b, dict) and b.get("type") == "tool_result" and b.get("tool_use_id")
}
if req_set <= present:
i += 1
continue
missing = [tid for tid in required if tid not in present]
nxt["content"] = [_synth_block(tid) for tid in missing] + nc
logger.warning(
"⚠️ Prepended synthetic tool_result for Anthropic adjacency "
f"(missing_ids={missing})"
)
repairs += len(missing)
i += 1
return repairs
# ------------------------------------------------------------------ #
# Claude-format sanitizer (used by agent_stream)
# ------------------------------------------------------------------ #
def sanitize_claude_messages(messages: List[Dict]) -> int:
"""
Validate and fix a Claude-format message list **in-place**.
Fixes handled:
- Anthropic adjacency: assistant tool_use must be immediately followed by
user message(s) containing matching tool_result blocks
- Leading orphaned tool_result user messages
- Mid-list tool_result blocks whose tool_use_id has no matching
tool_use in any preceding assistant message
Returns: number of removals plus adjacency repair operations (inserts/prepends).
"""
if not messages:
return 0
removed = 0
# 1. Adjacency repair (Anthropic: tool_result must be in the next user message)
adj_repairs = _repair_tool_use_adjacency(messages)
# 2. Remove leading orphaned tool_result user messages
while messages:
first = messages[0]
if first.get("role") != "user":
break
content = first.get("content", [])
if isinstance(content, list) and _has_block_type(content, "tool_result") \
and not _has_block_type(content, "text"):
logger.warning("⚠️ Removing leading orphaned tool_result user message")
messages.pop(0)
removed += 1
else:
break
# 3. Iteratively remove unmatched tool_use / tool_result until stable.
# Removing one broken message can orphan others (e.g. an assistant msg
# with both matched and unmatched tool_use — deleting it orphans the
# previously-matched tool_result). Loop until clean.
for _ in range(5):
use_ids: Set[str] = set()
result_ids: Set[str] = set()
for msg in messages:
for block in (msg.get("content") or []):
if not isinstance(block, dict):
continue
if block.get("type") == "tool_use" and block.get("id"):
use_ids.add(block["id"])
elif block.get("type") == "tool_result" and block.get("tool_use_id"):
result_ids.add(block["tool_use_id"])
bad_use = use_ids - result_ids
bad_result = result_ids - use_ids
if not bad_use and not bad_result:
break
pass_removed = 0
i = 0
while i < len(messages):
msg = messages[i]
role = msg.get("role")
content = msg.get("content", [])
if not isinstance(content, list):
i += 1
continue
if role == "assistant" and bad_use and any(
isinstance(b, dict) and b.get("type") == "tool_use"
and b.get("id") in bad_use for b in content
):
logger.warning(f"⚠️ Removing assistant msg with unmatched tool_use")
messages.pop(i)
pass_removed += 1
continue
if role == "user" and bad_result and _has_block_type(content, "tool_result"):
has_bad = any(
isinstance(b, dict) and b.get("type") == "tool_result"
and b.get("tool_use_id") in bad_result for b in content
)
if has_bad:
if not _has_block_type(content, "text"):
logger.warning(f"⚠️ Removing user msg with unmatched tool_result")
messages.pop(i)
pass_removed += 1
continue
else:
before = len(content)
msg["content"] = [
b for b in content
if not (isinstance(b, dict) and b.get("type") == "tool_result"
and b.get("tool_use_id") in bad_result)
]
pass_removed += before - len(msg["content"])
i += 1
removed += pass_removed
if pass_removed == 0:
break
# 4. Removals above can break adjacency; re-run repair only if something was removed.
if removed:
adj_repairs += _repair_tool_use_adjacency(messages)
if removed:
logger.info(f"🔧 Message validation: removed {removed} broken message(s)")
if adj_repairs:
logger.info(f"🔧 Message validation: adjacency repairs={adj_repairs}")
return removed + adj_repairs
# ------------------------------------------------------------------ #
# OpenAI-format sanitizer (used by minimax_bot, openai_compatible_bot)
# ------------------------------------------------------------------ #
def drop_orphaned_tool_results_openai(messages: List[Dict]) -> List[Dict]:
"""
Return a copy of *messages* (OpenAI format) with any ``role=tool``
messages removed if their ``tool_call_id`` does not match a
``tool_calls[].id`` in a preceding assistant message.
"""
known_ids: Set[str] = set()
cleaned: List[Dict] = []
for msg in messages:
if msg.get("role") == "assistant" and msg.get("tool_calls"):
for tc in msg["tool_calls"]:
tc_id = tc.get("id", "")
if tc_id:
known_ids.add(tc_id)
if msg.get("role") == "tool":
ref_id = msg.get("tool_call_id", "")
if ref_id and ref_id not in known_ids:
logger.warning(
f"[MessageSanitizer] Dropping orphaned tool result "
f"(tool_call_id={ref_id} not in known ids)"
)
continue
cleaned.append(msg)
return cleaned
# ------------------------------------------------------------------ #
# Internal helpers
# ------------------------------------------------------------------ #
def _has_block_type(content: list, block_type: str) -> bool:
return any(
isinstance(b, dict) and b.get("type") == block_type
for b in content
)
def _extract_text_from_content(content) -> str:
"""Extract plain text from a message content field (str or list of blocks)."""
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
parts = [
b.get("text", "")
for b in content
if isinstance(b, dict) and b.get("type") == "text"
]
return "\n".join(p for p in parts if p).strip()
return ""
def compress_turn_to_text_only(turn: Dict) -> Dict:
"""
Compress a full turn (with tool_use/tool_result chains) into a lightweight
text-only turn that keeps only the first user text and the last assistant text.
This preserves the conversational context (what the user asked and what the
agent concluded) while stripping out the bulky intermediate tool interactions.
Returns a new turn dict with a ``messages`` list; the original is not mutated.
"""
user_text = ""
last_assistant_text = ""
for msg in turn["messages"]:
role = msg.get("role")
content = msg.get("content", [])
if role == "user":
if isinstance(content, list) and _has_block_type(content, "tool_result"):
continue
if not user_text:
user_text = _extract_text_from_content(content)
elif role == "assistant":
text = _extract_text_from_content(content)
if text:
last_assistant_text = text
compressed_messages = []
if user_text:
compressed_messages.append({
"role": "user",
"content": [{"type": "text", "text": user_text}]
})
if last_assistant_text:
compressed_messages.append({
"role": "assistant",
"content": [{"type": "text", "text": last_assistant_text}]
})
return {"messages": compressed_messages}

View File

@@ -1,3 +1,4 @@
from __future__ import annotations
import time
import uuid
from dataclasses import dataclass, field

View File

@@ -1,3 +1,4 @@
from __future__ import annotations
import time
import uuid
from dataclasses import dataclass, field

View File

@@ -15,6 +15,7 @@ from agent.skills.types import (
)
from agent.skills.loader import SkillLoader
from agent.skills.manager import SkillManager
from agent.skills.service import SkillService
from agent.skills.formatter import format_skills_for_prompt
__all__ = [
@@ -25,5 +26,6 @@ __all__ = [
"LoadSkillsResult",
"SkillLoader",
"SkillManager",
"SkillService",
"format_skills_for_prompt",
]

View File

@@ -123,17 +123,63 @@ def should_include_skill(
return False
# Check environment variables (API keys)
# Simple rule: All required env vars must be set
# All required env vars must be set
required_env = metadata.requires.get('env', [])
if required_env:
for env_name in required_env:
if not has_env_var(env_name):
# Missing required API key → disable skill
return False
# Check anyEnv (at least one must be present)
any_env = metadata.requires.get('anyEnv', [])
if any_env:
if not any(has_env_var(e) for e in any_env):
return False
return True
def get_missing_requirements(
entry: SkillEntry,
current_platform: Optional[str] = None,
) -> Dict[str, List[str]]:
"""
Return a dict of missing requirements for a skill.
Empty dict means all requirements are met.
:param entry: SkillEntry to check
:param current_platform: Current platform (default: auto-detect)
:return: Dict like {"bins": ["curl"], "env": ["API_KEY"]}
"""
missing: Dict[str, List[str]] = {}
metadata = entry.metadata
if not metadata or not metadata.requires:
return missing
required_bins = metadata.requires.get('bins', [])
if required_bins:
missing_bins = [b for b in required_bins if not has_binary(b)]
if missing_bins:
missing['bins'] = missing_bins
any_bins = metadata.requires.get('anyBins', [])
if any_bins and not has_any_binary(any_bins):
missing['anyBins'] = any_bins
required_env = metadata.requires.get('env', [])
if required_env:
missing_env = [e for e in required_env if not has_env_var(e)]
if missing_env:
missing['env'] = missing_env
any_env = metadata.requires.get('anyEnv', [])
if any_env and not any(has_env_var(e) for e in any_env):
missing['anyEnv'] = any_env
return missing
def is_config_path_truthy(config: Dict, path: str) -> bool:
"""
Check if a config path resolves to a truthy value.

View File

@@ -2,7 +2,7 @@
Skill formatter for generating prompts from skills.
"""
from typing import List
from typing import Dict, List
from agent.skills.types import Skill, SkillEntry
@@ -23,12 +23,10 @@ def format_skills_for_prompt(skills: List[Skill]) -> str:
return ""
lines = [
"\n\nThe following skills provide specialized instructions for specific tasks.",
"Use the read tool to load a skill's file when the task matches its description.",
"",
"<available_skills>",
]
for skill in visible_skills:
lines.append(" <skill>")
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
@@ -53,6 +51,71 @@ def format_skill_entries_for_prompt(entries: List[SkillEntry]) -> str:
return format_skills_for_prompt(skills)
def format_unavailable_skills_for_prompt(
entries: List[SkillEntry],
missing_map: Dict[str, Dict[str, List[str]]],
) -> str:
"""
Format unavailable (requires-not-met) skills as brief setup hints
so the AI can guide users to configure them.
:param entries: List of unavailable skill entries
:param missing_map: Dict mapping skill name to its missing requirements
:return: Formatted prompt text
"""
if not entries:
return ""
lines = [
"",
"<unavailable_skills>",
"The following skills are installed but not yet ready. "
"Guide the user to complete the setup when relevant.",
]
for entry in entries:
skill = entry.skill
missing = missing_map.get(skill.name, {})
missing_parts = []
for key, values in missing.items():
missing_parts.append(f"{key}: {', '.join(values)}")
missing_str = "; ".join(missing_parts) if missing_parts else "unknown"
setup_hint = _extract_setup_hint(skill)
lines.append(" <skill>")
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
lines.append(f" <missing>{_escape_xml(missing_str)}</missing>")
if setup_hint:
lines.append(f" <setup>{_escape_xml(setup_hint)}</setup>")
lines.append(" </skill>")
lines.append("</unavailable_skills>")
return "\n".join(lines)
def _extract_setup_hint(skill: Skill) -> str:
"""
Extract the Setup section from SKILL.md content as a brief hint.
Returns the first few lines of the ## Setup section.
"""
content = skill.content
if not content:
return ""
import re
match = re.search(r'^##\s+Setup\s*\n(.*?)(?=\n##\s|\Z)', content, re.MULTILINE | re.DOTALL)
if not match:
return ""
setup_text = match.group(1).strip()
lines = setup_text.split('\n')
hint_lines = [l.strip() for l in lines[:6] if l.strip()]
return ' '.join(hint_lines)[:300]
def _escape_xml(text: str) -> str:
"""Escape XML special characters."""
return (text

View File

@@ -87,8 +87,8 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
if not isinstance(metadata_raw, dict):
return None
# Use metadata_raw directly (COW format)
meta_obj = metadata_raw
# Unwrap nested namespace (e.g. {"openclaw": {...}} or {"cowagent": {...}})
meta_obj = _unwrap_metadata_namespace(metadata_raw)
# Parse install specs
install_specs = []
@@ -128,6 +128,7 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
return SkillMetadata(
always=meta_obj.get('always', False),
default_enabled=meta_obj.get('default_enabled', True),
skill_key=meta_obj.get('skillKey'),
primary_env=meta_obj.get('primaryEnv'),
emoji=meta_obj.get('emoji'),
@@ -138,6 +139,25 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
)
_KNOWN_METADATA_NAMESPACES = {"cowagent", "openclaw"}
def _unwrap_metadata_namespace(metadata_raw: Dict[str, Any]) -> Dict[str, Any]:
"""
Unwrap a single-key namespace wrapper like {"cowagent": {...} or {"openclaw": {...}}}.
If the top-level dict has exactly one key matching a known namespace, return the inner dict.
Otherwise return the original dict unchanged.
"""
keys = set(metadata_raw.keys())
ns_keys = keys & _KNOWN_METADATA_NAMESPACES
if len(ns_keys) == 1 and len(keys) == 1:
ns = ns_keys.pop()
inner = metadata_raw[ns]
if isinstance(inner, dict):
return inner
return metadata_raw
def _normalize_string_list(value: Any) -> List[str]:
"""Normalize a value to a list of strings."""
if not value:

View File

@@ -12,25 +12,20 @@ from agent.skills.frontmatter import parse_frontmatter, parse_metadata, parse_bo
class SkillLoader:
"""Loads skills from various directories."""
def __init__(self, workspace_dir: Optional[str] = None):
"""
Initialize the skill loader.
:param workspace_dir: Agent workspace directory (for workspace-specific skills)
"""
self.workspace_dir = workspace_dir
def __init__(self):
pass
def load_skills_from_dir(self, dir_path: str, source: str) -> LoadSkillsResult:
"""
Load skills from a directory.
Discovery rules:
- Direct .md files in the root directory
- Recursive SKILL.md files under subdirectories
:param dir_path: Directory path to scan
:param source: Source identifier (e.g., 'managed', 'workspace', 'bundled')
:param source: Source identifier ('builtin' or 'custom')
:return: LoadSkillsResult with skills and diagnostics
"""
skills = []
@@ -96,7 +91,7 @@ class SkillLoader:
continue
# Check if this is a skill file
is_root_md = include_root_files and entry.endswith('.md')
is_root_md = include_root_files and entry.endswith('.md') and entry.upper() != 'README.MD'
is_skill_md = not include_root_files and entry == 'SKILL.md'
if not (is_root_md or is_skill_md):
@@ -137,6 +132,18 @@ class SkillLoader:
name = frontmatter.get('name', parent_dir_name)
description = frontmatter.get('description', '')
# Normalize name (handle both string and list)
if isinstance(name, list):
name = name[0] if name else parent_dir_name
elif not isinstance(name, str):
name = str(name) if name else parent_dir_name
# Normalize description (handle both string and list)
if isinstance(description, list):
description = ' '.join(str(d) for d in description if d)
elif not isinstance(description, str):
description = str(description) if description else ''
# Special handling for linkai-agent: dynamically load apps from config.json
if name == 'linkai-agent':
description = self._load_linkai_agent_description(skill_dir, description)
@@ -176,16 +183,13 @@ class SkillLoader:
import json
config_path = os.path.join(skill_dir, "config.json")
template_path = os.path.join(skill_dir, "config.json.template")
# Try to load config.json or fallback to template
config_file = config_path if os.path.exists(config_path) else template_path
if not os.path.exists(config_file):
return default_description
if not os.path.exists(config_path):
logger.debug(f"[SkillLoader] linkai-agent skipped: no config.json found")
return ""
try:
with open(config_file, 'r', encoding='utf-8') as f:
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
apps = config.get("apps", [])
@@ -206,61 +210,49 @@ class SkillLoader:
def load_all_skills(
self,
managed_dir: Optional[str] = None,
workspace_skills_dir: Optional[str] = None,
extra_dirs: Optional[List[str]] = None,
builtin_dir: Optional[str] = None,
custom_dir: Optional[str] = None,
) -> Dict[str, SkillEntry]:
"""
Load skills from all configured locations with precedence.
Load skills from builtin and custom directories.
Precedence (lowest to highest):
1. Extra directories
2. Managed skills directory
3. Workspace skills directory
:param managed_dir: Managed skills directory (e.g., ~/.cow/skills)
:param workspace_skills_dir: Workspace skills directory (e.g., workspace/skills)
:param extra_dirs: Additional directories to load skills from
1. builtin — project root ``skills/``, shipped with the codebase
2. custom — workspace ``skills/``, installed via cloud console or skill creator
Same-name custom skills override builtin ones.
:param builtin_dir: Built-in skills directory
:param custom_dir: Custom skills directory
:return: Dictionary mapping skill name to SkillEntry
"""
skill_map: Dict[str, SkillEntry] = {}
all_diagnostics = []
# Load from extra directories (lowest precedence)
if extra_dirs:
for extra_dir in extra_dirs:
if not os.path.exists(extra_dir):
continue
result = self.load_skills_from_dir(extra_dir, source='extra')
all_diagnostics.extend(result.diagnostics)
for skill in result.skills:
entry = self._create_skill_entry(skill)
skill_map[skill.name] = entry
# Load from managed directory
if managed_dir and os.path.exists(managed_dir):
result = self.load_skills_from_dir(managed_dir, source='managed')
# Load builtin skills (lower precedence)
if builtin_dir and os.path.exists(builtin_dir):
result = self.load_skills_from_dir(builtin_dir, source='builtin')
all_diagnostics.extend(result.diagnostics)
for skill in result.skills:
entry = self._create_skill_entry(skill)
skill_map[skill.name] = entry
# Load from workspace directory (highest precedence)
if workspace_skills_dir and os.path.exists(workspace_skills_dir):
result = self.load_skills_from_dir(workspace_skills_dir, source='workspace')
# Load custom skills (higher precedence, overrides builtin)
if custom_dir and os.path.exists(custom_dir):
result = self.load_skills_from_dir(custom_dir, source='custom')
all_diagnostics.extend(result.diagnostics)
for skill in result.skills:
entry = self._create_skill_entry(skill)
skill_map[skill.name] = entry
# Log diagnostics
if all_diagnostics:
logger.debug(f"Skill loading diagnostics: {len(all_diagnostics)} issues")
for diag in all_diagnostics[:5]: # Log first 5
for diag in all_diagnostics[:5]:
logger.debug(f" - {diag}")
logger.debug(f"Loaded {len(skill_map)} skills from all sources")
logger.debug(f"Loaded {len(skill_map)} skills total")
return skill_map
def _create_skill_entry(self, skill: Skill) -> SkillEntry:

View File

@@ -3,6 +3,7 @@ Skill manager for managing skill lifecycle and operations.
"""
import os
import json
from typing import Dict, List, Optional
from pathlib import Path
from common.log import logger
@@ -10,56 +11,139 @@ from agent.skills.types import Skill, SkillEntry, SkillSnapshot
from agent.skills.loader import SkillLoader
from agent.skills.formatter import format_skill_entries_for_prompt
SKILLS_CONFIG_FILE = "skills_config.json"
class SkillManager:
"""Manages skills for an agent."""
def __init__(
self,
workspace_dir: Optional[str] = None,
managed_skills_dir: Optional[str] = None,
extra_dirs: Optional[List[str]] = None,
builtin_dir: Optional[str] = None,
custom_dir: Optional[str] = None,
config: Optional[Dict] = None,
):
"""
Initialize the skill manager.
:param workspace_dir: Agent workspace directory
:param managed_skills_dir: Managed skills directory (e.g., ~/.cow/skills)
:param extra_dirs: Additional skill directories
:param builtin_dir: Built-in skills directory (project root ``skills/``)
:param custom_dir: Custom skills directory (workspace ``skills/``)
:param config: Configuration dictionary
"""
self.workspace_dir = workspace_dir
self.managed_skills_dir = managed_skills_dir or self._get_default_managed_dir()
self.extra_dirs = extra_dirs or []
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
self.builtin_dir = builtin_dir or os.path.join(project_root, 'skills')
self.custom_dir = custom_dir or os.path.join(project_root, 'workspace', 'skills')
self.config = config or {}
self.loader = SkillLoader(workspace_dir=workspace_dir)
self._skills_config_path = os.path.join(self.custom_dir, SKILLS_CONFIG_FILE)
# skills_config: full skill metadata keyed by name
# { "web-fetch": {"name": ..., "description": ..., "source": ..., "enabled": true}, ... }
self.skills_config: Dict[str, dict] = {}
self.loader = SkillLoader()
self.skills: Dict[str, SkillEntry] = {}
# Load skills on initialization
self.refresh_skills()
def _get_default_managed_dir(self) -> str:
"""Get the default managed skills directory."""
# Use project root skills directory as default
import os
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
return os.path.join(project_root, 'skills')
def refresh_skills(self):
"""Reload all skills from configured directories."""
workspace_skills_dir = None
if self.workspace_dir:
workspace_skills_dir = os.path.join(self.workspace_dir, 'skills')
"""Reload all skills from builtin and custom directories, then sync config."""
self.skills = self.loader.load_all_skills(
managed_dir=self.managed_skills_dir,
workspace_skills_dir=workspace_skills_dir,
extra_dirs=self.extra_dirs,
builtin_dir=self.builtin_dir,
custom_dir=self.custom_dir,
)
self._sync_skills_config()
logger.debug(f"SkillManager: Loaded {len(self.skills)} skills")
# ------------------------------------------------------------------
# skills_config.json management
# ------------------------------------------------------------------
def _load_skills_config(self) -> Dict[str, dict]:
"""Load skills_config.json from custom_dir. Returns empty dict if not found."""
if not os.path.exists(self._skills_config_path):
return {}
try:
with open(self._skills_config_path, "r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, dict):
return data
except Exception as e:
logger.warning(f"[SkillManager] Failed to load {SKILLS_CONFIG_FILE}: {e}")
return {}
def _save_skills_config(self):
"""Persist skills_config to custom_dir/skills_config.json."""
os.makedirs(self.custom_dir, exist_ok=True)
try:
with open(self._skills_config_path, "w", encoding="utf-8") as f:
json.dump(self.skills_config, f, indent=4, ensure_ascii=False)
except Exception as e:
logger.error(f"[SkillManager] Failed to save {SKILLS_CONFIG_FILE}: {e}")
def _sync_skills_config(self):
"""
Merge directory-scanned skills with the persisted config file.
- New skills: use metadata.default_enabled as initial enabled state.
- Existing skills: preserve their persisted enabled state.
- Skills that no longer exist on disk are removed.
- name/description/source are always refreshed from the latest scan.
"""
saved = self._load_skills_config()
merged: Dict[str, dict] = {}
for name, entry in self.skills.items():
skill = entry.skill
prev = saved.get(name, {})
category = prev.get("category", "skill")
if name in saved:
enabled = prev.get("enabled", True)
else:
enabled = entry.metadata.default_enabled if entry.metadata else True
merged[name] = {
"name": name,
"description": skill.description,
"source": prev.get("source") or skill.source,
"enabled": enabled,
"category": category,
}
self.skills_config = merged
self._save_skills_config()
def is_skill_enabled(self, name: str) -> bool:
"""
Check if a skill is enabled according to skills_config.
:param name: skill name
:return: True if enabled (default True if not in config)
"""
entry = self.skills_config.get(name)
if entry is None:
return True
return entry.get("enabled", True)
def set_skill_enabled(self, name: str, enabled: bool):
"""
Set a skill's enabled state and persist.
:param name: skill name
:param enabled: True to enable, False to disable
"""
if name not in self.skills_config:
raise ValueError(f"skill '{name}' not found in config")
self.skills_config[name]["enabled"] = enabled
self._save_skills_config()
def get_skills_config(self) -> Dict[str, dict]:
"""
Return the full skills_config dict (for query API).
:return: copy of skills_config
"""
return dict(self.skills_config)
def get_skill(self, name: str) -> Optional[SkillEntry]:
"""
@@ -78,53 +162,116 @@ class SkillManager:
"""
return list(self.skills.values())
@staticmethod
def _normalize_skill_filter(skill_filter: Optional[List[str]]) -> Optional[List[str]]:
"""Normalize a skill_filter list into a flat list of stripped names."""
if skill_filter is None:
return None
normalized = []
for item in skill_filter:
if isinstance(item, str):
name = item.strip()
if name:
normalized.append(name)
elif isinstance(item, list):
for subitem in item:
if isinstance(subitem, str):
name = subitem.strip()
if name:
normalized.append(name)
return normalized or None
def filter_skills(
self,
skill_filter: Optional[List[str]] = None,
include_disabled: bool = False,
) -> List[SkillEntry]:
"""
Filter skills based on criteria.
Simple rule: Skills are auto-enabled if requirements are met.
- Has required API keys → included
- Missing API keys → excluded
Filter skills that are eligible (enabled + requirements met).
:param skill_filter: List of skill names to include (None = all)
:param include_disabled: Whether to include skills with disable_model_invocation=True
:return: Filtered list of skill entries
:param include_disabled: Whether to include disabled skills
:return: Filtered list of eligible skill entries
"""
from agent.skills.config import should_include_skill
entries = list(self.skills.values())
# Check requirements (platform, binaries, env vars)
entries = [e for e in entries if should_include_skill(e, self.config)]
# Apply skill filter
if skill_filter is not None:
normalized = [name.strip() for name in skill_filter if name.strip()]
if normalized:
entries = [e for e in entries if e.skill.name in normalized]
# Filter out disabled skills unless explicitly requested
normalized = self._normalize_skill_filter(skill_filter)
if normalized is not None:
entries = [e for e in entries if e.skill.name in normalized]
if not include_disabled:
entries = [e for e in entries if not e.skill.disable_model_invocation]
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
return entries
def filter_unavailable_skills(
self,
skill_filter: Optional[List[str]] = None,
) -> tuple:
"""
Find skills that are enabled but have unmet requirements.
:param skill_filter: Optional list of skill names to include
:return: Tuple of (entries, missing_map) where missing_map maps
skill name to its missing requirements dict
"""
from agent.skills.config import should_include_skill, get_missing_requirements
entries = list(self.skills.values())
# Only enabled skills
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
normalized = self._normalize_skill_filter(skill_filter)
if normalized is not None:
entries = [e for e in entries if e.skill.name in normalized]
# Keep only those that fail should_include_skill (requirements not met)
unavailable = []
missing_map: Dict[str, dict] = {}
for e in entries:
if not should_include_skill(e, self.config):
missing = get_missing_requirements(e)
if missing:
unavailable.append(e)
missing_map[e.skill.name] = missing
return unavailable, missing_map
def build_skills_prompt(
self,
skill_filter: Optional[List[str]] = None,
) -> str:
"""
Build a formatted prompt containing available skills.
Build a formatted prompt containing available skills
and brief hints for unavailable ones.
:param skill_filter: Optional list of skill names to include
:return: Formatted skills prompt
"""
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
return format_skill_entries_for_prompt(entries)
from common.log import logger
from agent.skills.formatter import format_unavailable_skills_for_prompt
eligible = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
logger.debug(f"[SkillManager] Eligible: {len(eligible)} skills (total: {len(self.skills)})")
if eligible:
skill_names = [e.skill.name for e in eligible]
logger.debug(f"[SkillManager] Eligible skills: {skill_names}")
result = format_skill_entries_for_prompt(eligible)
unavailable, missing_map = self.filter_unavailable_skills(skill_filter=skill_filter)
if unavailable:
unavailable_names = [e.skill.name for e in unavailable]
logger.debug(f"[SkillManager] Unavailable skills (setup needed): {unavailable_names}")
result += format_unavailable_skills_for_prompt(unavailable, missing_map)
logger.debug(f"[SkillManager] Generated prompt length: {len(result)}")
return result
def build_skill_snapshot(
self,

285
agent/skills/service.py Normal file
View File

@@ -0,0 +1,285 @@
"""
Skill service for handling skill CRUD operations.
This service provides a unified interface for managing skills, which can be
called from the cloud control client (LinkAI), the local web console, or any
other management entry point.
"""
import os
import shutil
import zipfile
import tempfile
from typing import Dict, List, Optional
from common.log import logger
from agent.skills.types import Skill, SkillEntry
from agent.skills.manager import SkillManager
try:
import requests
except ImportError:
requests = None
class SkillService:
"""
High-level service for skill lifecycle management.
Wraps SkillManager and provides network-aware operations such as
downloading skill files from remote URLs.
"""
def __init__(self, skill_manager: SkillManager):
"""
:param skill_manager: The SkillManager instance to operate on
"""
self.manager = skill_manager
# ------------------------------------------------------------------
# query
# ------------------------------------------------------------------
def query(self) -> List[dict]:
"""
Query all skills and return a serialisable list.
Reads from skills_config.json (refreshes from disk if needed).
:return: list of skill info dicts
"""
self.manager.refresh_skills()
config = self.manager.get_skills_config()
result = list(config.values())
logger.info(f"[SkillService] query: {len(result)} skills found")
return result
# ------------------------------------------------------------------
# add / install
# ------------------------------------------------------------------
def add(self, payload: dict) -> None:
"""
Add (install) a skill from a remote payload.
Supported payload types:
1. ``type: "url"`` download individual files::
{
"name": "web_search",
"type": "url",
"enabled": true,
"files": [
{"url": "https://...", "path": "README.md"},
{"url": "https://...", "path": "scripts/main.py"}
]
}
2. ``type: "package"`` download a zip archive and extract::
{
"name": "plugin-custom-tool",
"type": "package",
"category": "skills",
"enabled": true,
"files": [{"url": "https://cdn.example.com/skills/custom-tool.zip"}]
}
:param payload: skill add payload from server
"""
name = payload.get("name")
if not name:
raise ValueError("skill name is required")
payload_type = payload.get("type", "url")
if payload_type == "package":
self._add_package(name, payload)
else:
self._add_url(name, payload)
self.manager.refresh_skills()
category = payload.get("category")
if category and name in self.manager.skills_config:
self.manager.skills_config[name]["category"] = category
self.manager._save_skills_config()
def _add_url(self, name: str, payload: dict) -> None:
"""Install a skill by downloading individual files."""
files = payload.get("files", [])
if not files:
raise ValueError("skill files list is empty")
skill_dir = os.path.join(self.manager.custom_dir, name)
tmp_dir = skill_dir + ".tmp"
if os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)
os.makedirs(tmp_dir, exist_ok=True)
try:
for file_info in files:
url = file_info.get("url")
rel_path = file_info.get("path")
if not url or not rel_path:
logger.warning(f"[SkillService] add: skip invalid file entry {file_info}")
continue
dest = os.path.join(tmp_dir, rel_path)
self._download_file(url, dest)
except Exception:
shutil.rmtree(tmp_dir, ignore_errors=True)
raise
if os.path.exists(skill_dir):
shutil.rmtree(skill_dir)
os.rename(tmp_dir, skill_dir)
logger.info(f"[SkillService] add: skill '{name}' installed via url ({len(files)} files)")
def _add_package(self, name: str, payload: dict) -> None:
"""
Install a skill by downloading a zip archive and extracting it.
If the archive contains a single top-level directory, that directory
is used as the skill folder directly; otherwise a new directory named
after the skill is created to hold the extracted contents.
"""
files = payload.get("files", [])
if not files or not files[0].get("url"):
raise ValueError("package url is required")
url = files[0]["url"]
skill_dir = os.path.join(self.manager.custom_dir, name)
with tempfile.TemporaryDirectory() as tmp_dir:
zip_path = os.path.join(tmp_dir, "package.zip")
self._download_file(url, zip_path)
if not zipfile.is_zipfile(zip_path):
raise ValueError(f"downloaded file is not a valid zip archive: {url}")
extract_dir = os.path.join(tmp_dir, "extracted")
with zipfile.ZipFile(zip_path, "r") as zf:
zf.extractall(extract_dir)
# Determine the actual content root.
# If the zip has a single top-level directory, use its contents
# so the skill folder is clean (no extra nesting).
top_items = [
item for item in os.listdir(extract_dir)
if not item.startswith(".")
]
if len(top_items) == 1:
single = os.path.join(extract_dir, top_items[0])
if os.path.isdir(single):
extract_dir = single
if os.path.exists(skill_dir):
shutil.rmtree(skill_dir)
shutil.copytree(extract_dir, skill_dir)
logger.info(f"[SkillService] add: skill '{name}' installed via package ({url})")
# ------------------------------------------------------------------
# open / close (enable / disable)
# ------------------------------------------------------------------
def open(self, payload: dict) -> None:
"""
Enable a skill by name.
:param payload: {"name": "skill_name"}
"""
name = payload.get("name")
if not name:
raise ValueError("skill name is required")
self.manager.set_skill_enabled(name, enabled=True)
logger.info(f"[SkillService] open: skill '{name}' enabled")
def close(self, payload: dict) -> None:
"""
Disable a skill by name.
:param payload: {"name": "skill_name"}
"""
name = payload.get("name")
if not name:
raise ValueError("skill name is required")
self.manager.set_skill_enabled(name, enabled=False)
logger.info(f"[SkillService] close: skill '{name}' disabled")
# ------------------------------------------------------------------
# delete
# ------------------------------------------------------------------
def delete(self, payload: dict) -> None:
"""
Delete a skill by removing its directory entirely.
:param payload: {"name": "skill_name"}
"""
name = payload.get("name")
if not name:
raise ValueError("skill name is required")
skill_dir = os.path.join(self.manager.custom_dir, name)
if os.path.exists(skill_dir):
shutil.rmtree(skill_dir)
logger.info(f"[SkillService] delete: removed directory {skill_dir}")
else:
logger.warning(f"[SkillService] delete: skill directory not found: {skill_dir}")
# Refresh will remove the deleted skill from config automatically
self.manager.refresh_skills()
logger.info(f"[SkillService] delete: skill '{name}' deleted")
# ------------------------------------------------------------------
# dispatch - single entry point for protocol messages
# ------------------------------------------------------------------
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
"""
Dispatch a skill management action and return a protocol-compatible
response dict.
:param action: one of query / add / open / close / delete
:param payload: action-specific payload (may be None for query)
:return: dict with action, code, message, payload
"""
payload = payload or {}
try:
if action == "query":
result_payload = self.query()
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
elif action == "add":
self.add(payload)
elif action == "open":
self.open(payload)
elif action == "close":
self.close(payload)
elif action == "delete":
self.delete(payload)
else:
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
return {"action": action, "code": 200, "message": "success", "payload": None}
except Exception as e:
logger.error(f"[SkillService] dispatch error: action={action}, error={e}")
return {"action": action, "code": 500, "message": str(e), "payload": None}
# ------------------------------------------------------------------
# internal helpers
# ------------------------------------------------------------------
@staticmethod
def _download_file(url: str, dest: str):
"""
Download a file from *url* and save to *dest*.
:param url: remote file URL
:param dest: local destination path
"""
if requests is None:
raise RuntimeError("requests library is required for downloading skill files")
dest_dir = os.path.dirname(dest)
if dest_dir:
os.makedirs(dest_dir, exist_ok=True)
resp = requests.get(url, timeout=60)
resp.raise_for_status()
with open(dest, "wb") as f:
f.write(resp.content)
logger.debug(f"[SkillService] downloaded {url} -> {dest}")

View File

@@ -2,6 +2,7 @@
Type definitions for skills system.
"""
from __future__ import annotations
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
@@ -28,6 +29,7 @@ class SkillInstallSpec:
class SkillMetadata:
"""Metadata for a skill from frontmatter."""
always: bool = False # Always include this skill
default_enabled: bool = True # Initial enabled state when first discovered
skill_key: Optional[str] = None # Override skill key
primary_env: Optional[str] = None # Primary environment variable
emoji: Optional[str] = None
@@ -44,7 +46,7 @@ class Skill:
description: str
file_path: str
base_dir: str
source: str # managed, workspace, bundled, etc.
source: str # builtin or custom
content: str # Full markdown content
disable_model_invocation: bool = False
frontmatter: Dict[str, Any] = field(default_factory=dict)

View File

@@ -45,16 +45,45 @@ def _import_optional_tools():
)
except Exception as e:
logger.error(f"[Tools] Scheduler tool failed to load: {e}")
# WebSearch Tool (conditionally loaded based on API key availability at init time)
try:
from agent.tools.web_search.web_search import WebSearch
tools['WebSearch'] = WebSearch
except ImportError as e:
logger.error(f"[Tools] WebSearch not loaded - missing dependency: {e}")
except Exception as e:
logger.error(f"[Tools] WebSearch failed to load: {e}")
# WebFetch Tool
try:
from agent.tools.web_fetch.web_fetch import WebFetch
tools['WebFetch'] = WebFetch
except ImportError as e:
logger.error(f"[Tools] WebFetch not loaded - missing dependency: {e}")
except Exception as e:
logger.error(f"[Tools] WebFetch failed to load: {e}")
# Vision Tool (conditionally loaded based on API key availability)
try:
from agent.tools.vision.vision import Vision
tools['Vision'] = Vision
except ImportError as e:
logger.error(f"[Tools] Vision not loaded - missing dependency: {e}")
except Exception as e:
logger.error(f"[Tools] Vision failed to load: {e}")
return tools
# Load optional tools
_optional_tools = _import_optional_tools()
EnvConfig = _optional_tools.get('EnvConfig')
SchedulerTool = _optional_tools.get('SchedulerTool')
WebSearch = _optional_tools.get('WebSearch')
WebFetch = _optional_tools.get('WebFetch')
Vision = _optional_tools.get('Vision')
GoogleSearch = _optional_tools.get('GoogleSearch')
FileSave = _optional_tools.get('FileSave')
FileSave = _optional_tools.get('FileSave')
Terminal = _optional_tools.get('Terminal')
@@ -92,6 +121,9 @@ __all__ = [
'MemoryGetTool',
'EnvConfig',
'SchedulerTool',
'WebSearch',
'WebFetch',
'Vision',
# Optional tools (may be None if dependencies not available)
# 'BrowserTool'
]

View File

@@ -3,6 +3,7 @@ Bash tool - Execute bash commands
"""
import os
import re
import sys
import subprocess
import tempfile
@@ -11,6 +12,7 @@ from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
from common.log import logger
from common.utils import expand_path
class Bash(BaseTool):
@@ -19,10 +21,11 @@ class Bash(BaseTool):
name: str = "bash"
description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.
IMPORTANT SAFETY GUIDELINES:
- You can freely create, modify, and delete files within the current workspace
- For operations outside the workspace or potentially destructive commands (rm -rf, system commands, etc.), always explain what you're about to do and ask for user confirmation first
- When in doubt, describe the command's purpose and ask for permission before executing"""
ENVIRONMENT: All API keys from env_config are auto-injected. Use $VAR_NAME directly.
SAFETY:
- Freely create/modify/delete files within the workspace
- For destructive and out-of-workspace commands, explain and confirm first"""
params: dict = {
"type": "object",
@@ -80,26 +83,32 @@ IMPORTANT SAFETY GUIDELINES:
env = os.environ.copy()
# Load environment variables from ~/.cow/.env if it exists
env_file = os.path.expanduser("~/.cow/.env")
env_file = expand_path("~/.cow/.env")
dotenv_vars = {}
if os.path.exists(env_file):
try:
from dotenv import dotenv_values
env_vars = dotenv_values(env_file)
env.update(env_vars)
logger.debug(f"[Bash] Loaded {len(env_vars)} variables from {env_file}")
dotenv_vars = dotenv_values(env_file)
env.update(dotenv_vars)
logger.debug(f"[Bash] Loaded {len(dotenv_vars)} variables from {env_file}")
except ImportError:
logger.debug("[Bash] python-dotenv not installed, skipping .env loading")
except Exception as e:
logger.debug(f"[Bash] Failed to load .env: {e}")
# getuid() only exists on Unix-like systems
if hasattr(os, 'getuid'):
logger.debug(f"[Bash] Process UID: {os.getuid()}")
else:
logger.debug(f"[Bash] Process User: {os.environ.get('USERNAME', os.environ.get('USER', 'unknown'))}")
# Debug logging
logger.debug(f"[Bash] CWD: {self.cwd}")
logger.debug(f"[Bash] Command: {command[:500]}")
logger.debug(f"[Bash] OPENAI_API_KEY in env: {'OPENAI_API_KEY' in env}")
logger.debug(f"[Bash] SHELL: {env.get('SHELL', 'not set')}")
logger.debug(f"[Bash] Python executable: {sys.executable}")
logger.debug(f"[Bash] Process UID: {os.getuid()}")
# On Windows, convert $VAR references to %VAR% for cmd.exe
if sys.platform == "win32":
env["PYTHONIOENCODING"] = "utf-8"
command = self._convert_env_vars_for_windows(command, dotenv_vars)
if command and not command.strip().lower().startswith("chcp"):
command = f"chcp 65001 >nul 2>&1 && {command}"
# Execute command with inherited environment variables
result = subprocess.run(
command,
@@ -108,6 +117,8 @@ IMPORTANT SAFETY GUIDELINES:
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding="utf-8",
errors="replace",
timeout=timeout,
env=env
)
@@ -131,6 +142,8 @@ IMPORTANT SAFETY GUIDELINES:
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding="utf-8",
errors="replace",
timeout=timeout,
env=env
)
@@ -258,3 +271,21 @@ IMPORTANT SAFETY GUIDELINES:
return "This command will recursively delete system directories"
return "" # No warning needed
@staticmethod
def _convert_env_vars_for_windows(command: str, dotenv_vars: dict) -> str:
"""
Convert bash-style $VAR / ${VAR} references to cmd.exe %VAR% syntax.
Only converts variables loaded from .env (user-configured API keys etc.)
to avoid breaking $PATH, jq expressions, regex, etc.
"""
if not dotenv_vars:
return command
def replace_match(m):
var_name = m.group(1) or m.group(2)
if var_name in dotenv_vars:
return f"%{var_name}%"
return m.group(0)
return re.sub(r'\$\{(\w+)\}|\$(\w+)', replace_match, command)

View File

@@ -7,6 +7,7 @@ import os
from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
from common.utils import expand_path
from agent.tools.utils.diff import (
strip_bom,
detect_line_ending,
@@ -178,7 +179,7 @@ class Edit(BaseTool):
:return: Absolute path
"""
# Expand ~ to user home directory
path = os.path.expanduser(path)
path = expand_path(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))

View File

@@ -9,6 +9,7 @@ from pathlib import Path
from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger
from common.utils import expand_path
# API Key 知识库:常见的环境变量及其描述
@@ -66,7 +67,7 @@ class EnvConfig(BaseTool):
def __init__(self, config: dict = None):
self.config = config or {}
# Store env config in ~/.cow directory (outside workspace for security)
self.env_dir = os.path.expanduser("~/.cow")
self.env_dir = expand_path("~/.cow")
self.env_path = os.path.join(self.env_dir, '.env')
self.agent_bridge = self.config.get("agent_bridge") # Reference to AgentBridge for hot reload
# Don't create .env file in __init__ to avoid issues during tool discovery
@@ -201,7 +202,8 @@ class EnvConfig(BaseTool):
"key": key,
"value": self._mask_value(value),
"description": description,
"exists": True
"exists": True,
"note": f"Value is masked for security. In bash, use ${key} directly — it is auto-injected."
})
else:
return ToolResult.success({

View File

@@ -7,6 +7,7 @@ from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
from common.utils import expand_path
DEFAULT_LIMIT = 500
@@ -51,7 +52,7 @@ class Ls(BaseTool):
absolute_path = self._resolve_path(path)
# Security check: Prevent accessing sensitive config directory
env_config_dir = os.path.expanduser("~/.cow")
env_config_dir = expand_path("~/.cow")
if os.path.abspath(absolute_path) == os.path.abspath(env_config_dir):
return ToolResult.fail(
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
@@ -93,7 +94,7 @@ class Ls(BaseTool):
results.append(entry + '/')
else:
results.append(entry)
except:
except Exception:
# Skip entries we can't stat
continue
@@ -133,7 +134,7 @@ class Ls(BaseTool):
def _resolve_path(self, path: str) -> str:
"""Resolve path to absolute path"""
# Expand ~ to user home directory
path = os.path.expanduser(path)
path = expand_path(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))

View File

@@ -77,7 +77,7 @@ class MemoryGetTool(BaseTool):
if not file_path.exists():
return ToolResult.fail(f"Error: File not found: {path}")
content = file_path.read_text()
content = file_path.read_text(encoding='utf-8')
lines = content.split('\n')
# Handle line range

View File

@@ -9,6 +9,7 @@ from pathlib import Path
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
from common.utils import expand_path
class Read(BaseTool):
@@ -47,7 +48,8 @@ class Read(BaseTool):
self.binary_extensions = {'.exe', '.dll', '.so', '.dylib', '.bin', '.dat', '.db', '.sqlite'}
self.archive_extensions = {'.zip', '.tar', '.gz', '.rar', '.7z', '.bz2', '.xz'}
self.pdf_extensions = {'.pdf'}
self.office_extensions = {'.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx'}
# Readable text formats (will be read with truncation)
self.text_extensions = {
'.txt', '.md', '.markdown', '.rst', '.log', '.csv', '.tsv', '.json', '.xml', '.yaml', '.yml',
@@ -56,7 +58,6 @@ class Read(BaseTool):
'.sh', '.bash', '.zsh', '.fish', '.ps1', '.bat', '.cmd',
'.sql', '.r', '.m', '.swift', '.kt', '.scala', '.clj', '.erl', '.ex',
'.dockerfile', '.makefile', '.cmake', '.gradle', '.properties', '.ini', '.conf', '.cfg',
'.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx' # Office documents
}
def execute(self, args: Dict[str, Any]) -> ToolResult:
@@ -66,10 +67,12 @@ class Read(BaseTool):
:param args: Contains file path and optional offset/limit parameters
:return: File content or error message
"""
path = args.get("path", "").strip()
# Support 'location' as alias for 'path' (LLM may use it from skill listing)
path = args.get("path", "") or args.get("location", "")
path = path.strip() if isinstance(path, str) else ""
offset = args.get("offset")
limit = args.get("limit")
if not path:
return ToolResult.fail("Error: path parameter is required")
@@ -77,7 +80,7 @@ class Read(BaseTool):
absolute_path = self._resolve_path(path)
# Security check: Prevent reading sensitive config files
env_config_path = os.path.expanduser("~/.cow/.env")
env_config_path = expand_path("~/.cow/.env")
if os.path.abspath(absolute_path) == os.path.abspath(env_config_path):
return ToolResult.fail(
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
@@ -117,7 +120,11 @@ class Read(BaseTool):
# Check if PDF
if file_ext in self.pdf_extensions:
return self._read_pdf(absolute_path, path, offset, limit)
# Check if Office document (.docx, .xlsx, .pptx, etc.)
if file_ext in self.office_extensions:
return self._read_office(absolute_path, path, file_ext, offset, limit)
# Read text file (with truncation for large files)
return self._read_text(absolute_path, path, offset, limit)
@@ -129,7 +136,7 @@ class Read(BaseTool):
:return: Absolute path
"""
# Expand ~ to user home directory
path = os.path.expanduser(path)
path = expand_path(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))
@@ -237,8 +244,8 @@ class Read(BaseTool):
"message": f"文件过大 ({format_size(file_size)} > 50MB),无法读取内容。文件路径: {absolute_path}"
})
# Read file
with open(absolute_path, 'r', encoding='utf-8') as f:
# Read file (utf-8-sig strips BOM automatically on Windows)
with open(absolute_path, 'r', encoding='utf-8-sig') as f:
content = f.read()
# Truncate content if too long (20K characters max for model context)
@@ -334,6 +341,116 @@ class Read(BaseTool):
except Exception as e:
return ToolResult.fail(f"Error reading file: {str(e)}")
def _read_office(self, absolute_path: str, display_path: str, file_ext: str,
offset: int = None, limit: int = None) -> ToolResult:
"""Read Office documents (.docx, .xlsx, .pptx) using python-docx / openpyxl / python-pptx."""
try:
text = self._extract_office_text(absolute_path, file_ext)
except ImportError as e:
return ToolResult.fail(str(e))
except Exception as e:
return ToolResult.fail(f"Error reading Office document: {e}")
if not text or not text.strip():
return ToolResult.success({
"content": f"[Office file {Path(absolute_path).name}: no text content could be extracted]",
})
all_lines = text.split('\n')
total_lines = len(all_lines)
start_line = 0
if offset is not None:
if offset < 0:
start_line = max(0, total_lines + offset)
else:
start_line = max(0, offset - 1)
if start_line >= total_lines:
return ToolResult.fail(
f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
)
selected_content = text
user_limited_lines = None
if limit is not None:
end_line = min(start_line + limit, total_lines)
selected_content = '\n'.join(all_lines[start_line:end_line])
user_limited_lines = end_line - start_line
elif offset is not None:
selected_content = '\n'.join(all_lines[start_line:])
truncation = truncate_head(selected_content)
start_line_display = start_line + 1
output_text = ""
if truncation.truncated:
end_line_display = start_line_display + truncation.output_lines - 1
next_offset = end_line_display + 1
output_text = truncation.content
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
remaining = total_lines - (start_line + user_limited_lines)
next_offset = start_line + user_limited_lines + 1
output_text = truncation.content
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
else:
output_text = truncation.content
return ToolResult.success({
"content": output_text,
"total_lines": total_lines,
"start_line": start_line_display,
"output_lines": truncation.output_lines,
})
@staticmethod
def _extract_office_text(absolute_path: str, file_ext: str) -> str:
"""Extract plain text from an Office document."""
if file_ext in ('.docx', '.doc'):
try:
from docx import Document
except ImportError:
raise ImportError("Error: python-docx library not installed. Install with: pip install python-docx")
doc = Document(absolute_path)
paragraphs = [p.text for p in doc.paragraphs]
for table in doc.tables:
for row in table.rows:
paragraphs.append('\t'.join(cell.text for cell in row.cells))
return '\n'.join(paragraphs)
if file_ext in ('.xlsx', '.xls'):
try:
from openpyxl import load_workbook
except ImportError:
raise ImportError("Error: openpyxl library not installed. Install with: pip install openpyxl")
wb = load_workbook(absolute_path, read_only=True, data_only=True)
parts = []
for ws in wb.worksheets:
parts.append(f"--- Sheet: {ws.title} ---")
for row in ws.iter_rows(values_only=True):
parts.append('\t'.join(str(c) if c is not None else '' for c in row))
wb.close()
return '\n'.join(parts)
if file_ext in ('.pptx', '.ppt'):
try:
from pptx import Presentation
except ImportError:
raise ImportError("Error: python-pptx library not installed. Install with: pip install python-pptx")
prs = Presentation(absolute_path)
parts = []
for i, slide in enumerate(prs.slides, 1):
parts.append(f"--- Slide {i} ---")
for shape in slide.shapes:
if shape.has_text_frame:
for para in shape.text_frame.paragraphs:
text = para.text.strip()
if text:
parts.append(text)
return '\n'.join(parts)
return ""
def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
"""
Read PDF file content

View File

@@ -6,6 +6,7 @@ import os
from typing import Optional
from config import conf
from common.log import logger
from common.utils import expand_path
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
@@ -31,7 +32,7 @@ def init_scheduler(agent_bridge) -> bool:
from agent.tools.scheduler.scheduler_service import SchedulerService
# Get workspace from config
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
store_path = os.path.join(workspace_root, "scheduler", "tasks.json")
# Create task store
@@ -112,11 +113,15 @@ def _execute_agent_task(task: dict, agent_bridge):
logger.info(f"[Scheduler] Task {task['id']}: Executing agent task '{task_description}'")
# Create a unique session_id for this scheduled task to avoid polluting user's conversation
# Format: scheduler_<receiver>_<task_id> to ensure isolation
scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
# Create context for Agent
context = Context(ContextType.TEXT, task_description)
context["receiver"] = receiver
context["isgroup"] = is_group
context["session_id"] = receiver
context["session_id"] = scheduler_session_id
# Channel-specific setup
if channel_type == "web":
@@ -129,18 +134,20 @@ def _execute_agent_task(task: dict, agent_bridge):
elif channel_type == "dingtalk":
# DingTalk requires msg object, set to None for scheduled tasks
context["msg"] = None
# 如果是单聊,需要传递 sender_staff_id
if not is_group:
sender_staff_id = action.get("dingtalk_sender_staff_id")
if sender_staff_id:
context["dingtalk_sender_staff_id"] = sender_staff_id
elif channel_type == "wecom_bot":
context["msg"] = None
# Use Agent to execute the task
# Mark this as a scheduled task execution to prevent recursive task creation
context["is_scheduled_task"] = True
try:
reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=True)
# Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations
reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=False)
if reply and reply.content:
# Send the reply via channel
@@ -228,7 +235,11 @@ def _execute_send_message(task: dict, agent_bridge):
logger.debug(f"[Scheduler] DingTalk single chat: sender_staff_id={sender_staff_id}")
else:
logger.warning(f"[Scheduler] Task {task['id']}: DingTalk single chat message missing sender_staff_id")
elif channel_type == "wecom_bot":
context["msg"] = None
elif channel_type == "qq":
context["msg"] = None
# Create reply
reply = Reply(ReplyType.TEXT, content)
@@ -321,31 +332,31 @@ def _execute_tool_call(task: dict, agent_bridge):
context["request_id"] = request_id
logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
elif channel_type == "feishu":
# Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
context["receive_id_type"] = "chat_id" if is_group else "open_id"
context["msg"] = None
logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
elif channel_type == "wecom_bot":
context["msg"] = None
reply = Reply(ReplyType.TEXT, content)
# Get channel and send
from channel.channel_factory import create_channel
try:
channel = create_channel(channel_type)
if channel:
# For web channel, register the request_id to session mapping
if channel_type == "web" and hasattr(channel, 'request_to_session'):
channel.request_to_session[request_id] = receiver
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
channel.send(reply, context)
logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
else:
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
except Exception as e:
logger.error(f"[Scheduler] Failed to send tool result: {e}")
except Exception as e:
logger.error(f"[Scheduler] Error in _execute_tool_call: {e}")
@@ -378,6 +389,10 @@ def _execute_skill_call(task: dict, agent_bridge):
logger.info(f"[Scheduler] Task {task['id']}: Executing skill '{skill_name}' with params {skill_params}")
# Create a unique session_id for this scheduled task to avoid polluting user's conversation
# Format: scheduler_<receiver>_<task_id> to ensure isolation
scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
# Build a natural language query for the Agent to execute the skill
# Format: "Use skill-name to do something with params"
param_str = ", ".join([f"{k}={v}" for k, v in skill_params.items()])
@@ -389,7 +404,7 @@ def _execute_skill_call(task: dict, agent_bridge):
context = Context(ContextType.TEXT, query)
context["receiver"] = receiver
context["isgroup"] = is_group
context["session_id"] = receiver
context["session_id"] = scheduler_session_id
# Channel-specific setup
if channel_type == "web":
@@ -399,10 +414,13 @@ def _execute_skill_call(task: dict, agent_bridge):
elif channel_type == "feishu":
context["receive_id_type"] = "chat_id" if is_group else "open_id"
context["msg"] = None
elif channel_type == "wecom_bot":
context["msg"] = None
# Use Agent to execute the skill
try:
reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=True)
# Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations
reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=False)
if reply and reply.content:
content = reply.content
@@ -440,8 +458,7 @@ def attach_scheduler_to_tool(tool, context: Context = None):
if context:
tool.current_context = context
# Also set channel_type from config
channel_type = conf().get("channel_type", "unknown")
channel_type = context.get("channel_type") or conf().get("channel_type", "unknown")
if not tool.config:
tool.config = {}
tool.config["channel_type"] = channel_type

View File

@@ -61,8 +61,7 @@ class SchedulerService:
self._check_and_execute_tasks()
except Exception as e:
logger.error(f"[Scheduler] Error in scheduler loop: {e}")
# Sleep for 30 seconds between checks
time.sleep(30)
def _check_and_execute_tasks(self):
@@ -85,12 +84,9 @@ class SchedulerService:
"last_run_at": now.isoformat()
})
else:
# One-time task, disable it
self.task_store.update_task(task['id'], {
"enabled": False,
"last_run_at": now.isoformat()
})
logger.info(f"[Scheduler] One-time task completed and disabled: {task['id']}")
# One-time task completed, remove it
self.task_store.delete_task(task['id'])
logger.info(f"[Scheduler] One-time task completed and removed: {task['id']}")
except Exception as e:
logger.error(f"[Scheduler] Error processing task {task.get('id')}: {e}")
@@ -127,14 +123,11 @@ class SchedulerService:
if time_diff > 300: # 5 minutes
logger.warning(f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, skipping and scheduling next run")
# For one-time tasks, disable them
# For one-time tasks, remove them directly
schedule = task.get("schedule", {})
if schedule.get("type") == "once":
self.task_store.update_task(task['id'], {
"enabled": False,
"last_run_at": now.isoformat()
})
logger.info(f"[Scheduler] One-time task {task['id']} expired, disabled")
self.task_store.delete_task(task['id'])
logger.info(f"[Scheduler] One-time task {task['id']} expired, removed")
return False
# For recurring tasks, calculate next run from now
@@ -147,7 +140,7 @@ class SchedulerService:
return False
return now >= next_run
except:
except Exception:
return False
def _calculate_next_run(self, task: dict, from_time: datetime) -> Optional[datetime]:
@@ -195,7 +188,7 @@ class SchedulerService:
# Only return if in the future
if run_at > from_time:
return run_at
except:
except Exception:
pass
return None

View File

@@ -20,7 +20,8 @@ class SchedulerTool(BaseTool):
name: str = "scheduler"
description: str = (
"创建、查询和管理定时任务。支持固定消息和AI任务两种类型\n\n"
"创建、查询和管理定时任务(提醒、周期性任务等)\n\n"
"⚠️ 重要:仅当需要「定时/提醒/每天/每周/X分钟后/X点」等延迟或周期执行时才使用此工具。"
"使用方法:\n"
"- 创建action='create', name='任务名', message/ai_task='内容', schedule_type='once/interval/cron', schedule_value='...'\n"
"- 查询action='list' / action='get', task_id='任务ID'\n"
@@ -53,7 +54,7 @@ class SchedulerTool(BaseTool):
},
"ai_task": {
"type": "string",
"description": "AI任务描述 (与message二选一)'搜索今日新闻''查询天气'"
"description": "AI任务描述 (与message二选一)用于定时让AI执行的任务"
},
"schedule_type": {
"type": "string",
@@ -423,7 +424,7 @@ class SchedulerTool(BaseTool):
try:
dt = datetime.fromisoformat(run_at)
return f"一次性 ({dt.strftime('%Y-%m-%d %H:%M')})"
except:
except Exception:
return "一次性"
return "未知"
@@ -437,6 +438,6 @@ class SchedulerTool(BaseTool):
return msg.other_user_nickname or "群聊"
else:
return msg.from_user_nickname or "用户"
except:
except Exception:
pass
return "未知"

View File

@@ -8,6 +8,7 @@ import threading
from datetime import datetime
from typing import Dict, List, Optional
from pathlib import Path
from common.utils import expand_path
class TaskStore:
@@ -24,7 +25,7 @@ class TaskStore:
"""
if store_path is None:
# Default to ~/cow/scheduler/tasks.json
home = os.path.expanduser("~")
home = expand_path("~")
store_path = os.path.join(home, "cow", "scheduler", "tasks.json")
self.store_path = store_path
@@ -71,7 +72,7 @@ class TaskStore:
with open(self.store_path, 'r') as src:
with open(backup_path, 'w') as dst:
dst.write(src.read())
except:
except Exception:
pass
# Save tasks

View File

@@ -7,20 +7,21 @@ from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool, ToolResult
from common.utils import expand_path
class Send(BaseTool):
"""Tool for sending files to the user"""
name: str = "send"
description: str = "Send a file (image, video, audio, document) to the user. Use this when the user explicitly asks to send/share a file."
description: str = "Send a LOCAL file (image, video, audio, document) to the user. Only for local file paths. Do NOT use this for URLs — URLs should be included directly in your text reply, the system will handle them automatically."
params: dict = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the file to send. Can be absolute path or relative to workspace."
"description": "Local file path to send. Must be an absolute path or relative to workspace. Do NOT pass URLs here."
},
"message": {
"type": "string",
@@ -102,7 +103,7 @@ class Send(BaseTool):
def _resolve_path(self, path: str) -> str:
"""Resolve path to absolute path"""
path = os.path.expanduser(path)
path = expand_path(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))

View File

@@ -8,7 +8,7 @@ Truncation is based on two independent limits - whichever is hit first wins:
Never returns partial lines (except bash tail truncation edge case).
"""
from typing import Dict, Any, Optional, Literal
from typing import Dict, Any, Optional, Literal, Tuple
DEFAULT_MAX_LINES = 2000
@@ -278,7 +278,7 @@ def _truncate_string_to_bytes_from_end(text: str, max_bytes: int) -> str:
return encoded[start:].decode('utf-8', errors='ignore')
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> tuple[str, bool]:
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> Tuple[str, bool]:
"""
Truncate single line to max characters, add [truncated] suffix.
Used for grep match lines.

View File

@@ -0,0 +1 @@
from agent.tools.vision.vision import Vision

View File

@@ -0,0 +1,280 @@
"""
Vision tool - Analyze images using OpenAI-compatible Vision API.
Supports local files (auto base64-encoded) and HTTP URLs.
Providers: OpenAI (preferred) > LinkAI (fallback).
"""
import base64
import os
import subprocess
import tempfile
from typing import Any, Dict, Optional, Tuple
import requests
from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger
from config import conf
DEFAULT_MODEL = "gpt-4.1-mini"
DEFAULT_TIMEOUT = 60
MAX_TOKENS = 1000
COMPRESS_THRESHOLD = 1_048_576 # 1 MB
SUPPORTED_EXTENSIONS = {
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"png": "image/png",
"gif": "image/gif",
"webp": "image/webp",
}
class Vision(BaseTool):
"""Analyze images using OpenAI-compatible Vision API"""
name: str = "vision"
description: str = (
"Analyze a local image or image URL (jpg/jpeg/png) using Vision API. "
"Can describe content, extract text, identify objects, colors, etc. "
"Requires OPENAI_API_KEY or LINKAI_API_KEY."
)
params: dict = {
"type": "object",
"properties": {
"image": {
"type": "string",
"description": "Local file path or HTTP(S) URL of the image to analyze",
},
"question": {
"type": "string",
"description": "Question to ask about the image",
},
"model": {
"type": "string",
"description": (
f"Vision model to use (default: {DEFAULT_MODEL}). "
"Options: gpt-4.1-mini, gpt-4.1, gpt-4o-mini, gpt-4o"
),
},
},
"required": ["image", "question"],
}
def __init__(self, config: dict = None):
self.config = config or {}
@staticmethod
def is_available() -> bool:
return bool(
conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
or conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
)
def execute(self, args: Dict[str, Any]) -> ToolResult:
image = args.get("image", "").strip()
question = args.get("question", "").strip()
model = args.get("model", DEFAULT_MODEL).strip() or DEFAULT_MODEL
if not image:
return ToolResult.fail("Error: 'image' parameter is required")
if not question:
return ToolResult.fail("Error: 'question' parameter is required")
api_key, api_base, extra_headers = self._resolve_provider()
if not api_key:
return ToolResult.fail(
"Error: No API key configured for Vision.\n"
"Please configure one of the following using env_config tool:\n"
" 1. OPENAI_API_KEY (preferred): env_config(action=\"set\", key=\"OPENAI_API_KEY\", value=\"your-key\")\n"
" 2. LINKAI_API_KEY (fallback): env_config(action=\"set\", key=\"LINKAI_API_KEY\", value=\"your-key\")\n\n"
"Get your key at: https://platform.openai.com/api-keys or https://link-ai.tech"
)
try:
image_content = self._build_image_content(image)
except Exception as e:
return ToolResult.fail(f"Error: {e}")
try:
return self._call_api(api_key, api_base, model, question, image_content, extra_headers)
except requests.Timeout:
return ToolResult.fail(f"Error: Vision API request timed out after {DEFAULT_TIMEOUT}s")
except requests.ConnectionError:
return ToolResult.fail("Error: Failed to connect to Vision API")
except Exception as e:
logger.error(f"[Vision] Unexpected error: {e}", exc_info=True)
return ToolResult.fail(f"Error: Vision API call failed - {e}")
def _resolve_provider(self) -> Tuple[Optional[str], str, dict]:
"""Resolve API key, base URL and extra headers. Priority: conf() > env vars."""
api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
if api_key:
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
or "https://api.openai.com/v1"
return api_key, self._ensure_v1(api_base), {}
api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
if api_key:
api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \
or "https://api.link-ai.tech"
logger.debug("[Vision] Using LinkAI API (OPENAI_API_KEY not set)")
from common.utils import get_cloud_headers
extra = get_cloud_headers(api_key)
extra.pop("Authorization", None)
extra.pop("Content-Type", None)
return api_key, self._ensure_v1(api_base), extra
return None, "", {}
@staticmethod
def _ensure_v1(api_base: str) -> str:
"""Append /v1 if the base URL doesn't already end with a versioned path."""
if not api_base:
return api_base
# Already has /v1 or similar version suffix
if api_base.rstrip("/").split("/")[-1].startswith("v"):
return api_base
return api_base.rstrip("/") + "/v1"
def _build_image_content(self, image: str) -> dict:
"""Build the image_url content block for the API request."""
if image.startswith(("http://", "https://")):
return {"type": "image_url", "image_url": {"url": image}}
if not os.path.isfile(image):
raise FileNotFoundError(f"Image file not found: {image}")
ext = image.rsplit(".", 1)[-1].lower() if "." in image else ""
mime_type = SUPPORTED_EXTENSIONS.get(ext)
if not mime_type:
raise ValueError(
f"Unsupported image format '.{ext}'. "
f"Supported: {', '.join(SUPPORTED_EXTENSIONS.keys())}"
)
file_path = self._maybe_compress(image)
try:
with open(file_path, "rb") as f:
b64 = base64.b64encode(f.read()).decode("ascii")
finally:
if file_path != image and os.path.exists(file_path):
os.remove(file_path)
data_url = f"data:{mime_type};base64,{b64}"
return {"type": "image_url", "image_url": {"url": data_url}}
@staticmethod
def _maybe_compress(path: str) -> str:
"""Compress image to under COMPRESS_THRESHOLD with max long-edge 1536px."""
file_size = os.path.getsize(path)
if file_size <= COMPRESS_THRESHOLD:
return path
tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
tmp.close()
def _try_sips(max_dim: str, quality: str) -> bool:
try:
subprocess.run(
["sips", "-Z", max_dim, "-s", "formatOptions", quality,
path, "--out", tmp.name],
capture_output=True, check=True,
)
return True
except (FileNotFoundError, subprocess.CalledProcessError):
return False
def _try_convert(max_dim: str, quality: str) -> bool:
try:
subprocess.run(
["convert", path, "-resize", f"{max_dim}x{max_dim}>",
"-quality", quality, tmp.name],
capture_output=True, check=True,
)
return True
except (FileNotFoundError, subprocess.CalledProcessError):
return False
attempts = [
("1536", "85"),
("1536", "70"),
("1536", "50"),
]
for max_dim, quality in attempts:
ok = _try_sips(max_dim, quality) or _try_convert(max_dim, quality)
if not ok:
continue
new_size = os.path.getsize(tmp.name)
logger.debug(f"[Vision] Compressed image "
f"({file_size // 1024}KB -> {new_size // 1024}KB, "
f"max_dim={max_dim}, q={quality})")
if new_size <= COMPRESS_THRESHOLD:
return tmp.name
if os.path.exists(tmp.name) and os.path.getsize(tmp.name) > 0:
return tmp.name
os.remove(tmp.name)
return path
def _call_api(self, api_key: str, api_base: str, model: str,
question: str, image_content: dict, extra_headers: dict = None) -> ToolResult:
payload = {
"model": model,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": question},
image_content,
],
}
],
"max_tokens": MAX_TOKENS,
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
**(extra_headers or {}),
}
resp = requests.post(
f"{api_base}/chat/completions",
headers=headers,
json=payload,
timeout=DEFAULT_TIMEOUT,
)
if resp.status_code == 401:
return ToolResult.fail("Error: Invalid API key. Please check your configuration.")
if resp.status_code == 429:
return ToolResult.fail("Error: API rate limit reached. Please try again later.")
if resp.status_code != 200:
return ToolResult.fail(f"Error: Vision API returned HTTP {resp.status_code}: {resp.text[:200]}")
data = resp.json()
if "error" in data:
msg = data["error"].get("message", "Unknown API error")
return ToolResult.fail(f"Error: Vision API error - {msg}")
content = ""
choices = data.get("choices", [])
if choices:
content = choices[0].get("message", {}).get("content", "")
usage = data.get("usage", {})
result = {
"model": model,
"content": content,
"usage": {
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
},
}
return ToolResult.success(result)

View File

View File

@@ -0,0 +1,444 @@
"""
Web Fetch tool - Fetch and extract readable content from web pages and remote files.
Supports:
- HTML web pages: extracts readable text content
- Document files (PDF, Word, TXT, Markdown, etc.): downloads to workspace/tmp and parses content
"""
import os
import re
import uuid
from typing import Dict, Any, Optional, Set
from urllib.parse import urlparse, unquote
import requests
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size
from common.log import logger
DEFAULT_TIMEOUT = 30
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
DEFAULT_HEADERS = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36",
"Accept": "*/*",
}
# Supported document file extensions
PDF_SUFFIXES: Set[str] = {".pdf"}
WORD_SUFFIXES: Set[str] = {".docx"}
TEXT_SUFFIXES: Set[str] = {".txt", ".md", ".markdown", ".rst", ".csv", ".tsv", ".log"}
SPREADSHEET_SUFFIXES: Set[str] = {".xls", ".xlsx"}
PPT_SUFFIXES: Set[str] = {".ppt", ".pptx"}
ALL_DOC_SUFFIXES = PDF_SUFFIXES | WORD_SUFFIXES | TEXT_SUFFIXES | SPREADSHEET_SUFFIXES | PPT_SUFFIXES
_CHARSET_RE = re.compile(r'charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
_META_CHARSET_RE = re.compile(rb'<meta[^>]+charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
_META_HTTP_EQUIV_RE = re.compile(
rb'<meta[^>]+http-equiv\s*=\s*["\']?Content-Type["\']?[^>]+content\s*=\s*["\'][^"\']*charset=([\w\-]+)',
re.IGNORECASE,
)
def _extract_charset_from_content_type(content_type: str) -> Optional[str]:
"""Extract charset from Content-Type header value."""
m = _CHARSET_RE.search(content_type)
return m.group(1) if m else None
def _extract_charset_from_html_meta(raw_bytes: bytes) -> Optional[str]:
"""Extract charset from HTML <meta> tags in the first few KB of raw bytes."""
m = _META_CHARSET_RE.search(raw_bytes)
if m:
return m.group(1).decode("ascii", errors="ignore")
m = _META_HTTP_EQUIV_RE.search(raw_bytes)
if m:
return m.group(1).decode("ascii", errors="ignore")
return None
def _get_url_suffix(url: str) -> str:
"""Extract file extension from URL path, ignoring query params."""
path = urlparse(url).path
return os.path.splitext(path)[-1].lower()
def _is_document_url(url: str) -> bool:
"""Check if URL points to a downloadable document file."""
suffix = _get_url_suffix(url)
return suffix in ALL_DOC_SUFFIXES
class WebFetch(BaseTool):
"""Tool for fetching web pages and remote document files"""
name: str = "web_fetch"
description: str = (
"Fetch content from a http/https URL. For web pages, extracts readable text. "
"For document files (PDF, Word, TXT, Markdown, Excel, PPT), downloads and parses the file content. "
"Supported file types: .pdf, .docx, .txt, .md, .csv, .xls, .xlsx, .ppt, .pptx"
)
params: dict = {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The HTTP/HTTPS URL to fetch (web page or document file link)"
}
},
"required": ["url"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
def execute(self, args: Dict[str, Any]) -> ToolResult:
url = args.get("url", "").strip()
if not url:
return ToolResult.fail("Error: 'url' parameter is required")
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return ToolResult.fail("Error: Invalid URL (must start with http:// or https://)")
if _is_document_url(url):
return self._fetch_document(url)
return self._fetch_webpage(url)
# ---- Web page fetching ----
def _fetch_webpage(self, url: str) -> ToolResult:
"""Fetch and extract readable text from an HTML web page."""
parsed = urlparse(url)
try:
response = requests.get(
url,
headers=DEFAULT_HEADERS,
timeout=DEFAULT_TIMEOUT,
allow_redirects=True,
)
response.raise_for_status()
except requests.Timeout:
return ToolResult.fail(f"Error: Request timed out after {DEFAULT_TIMEOUT}s")
except requests.ConnectionError:
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
except requests.HTTPError as e:
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
except Exception as e:
return ToolResult.fail(f"Error: Failed to fetch URL: {e}")
content_type = response.headers.get("Content-Type", "")
if self._is_binary_content_type(content_type) and not _is_document_url(url):
return self._handle_download_by_content_type(url, response, content_type)
response.encoding = self._detect_encoding(response)
html = response.text
title = self._extract_title(html)
text = self._extract_text(html)
return ToolResult.success(f"Title: {title}\n\nContent:\n{text}")
# ---- Document fetching ----
def _fetch_document(self, url: str) -> ToolResult:
"""Download a document file and extract its text content."""
suffix = _get_url_suffix(url)
parsed = urlparse(url)
filename = self._extract_filename(url)
tmp_dir = self._ensure_tmp_dir()
local_path = os.path.join(tmp_dir, filename)
logger.info(f"[WebFetch] Downloading document: {url} -> {local_path}")
try:
response = requests.get(
url,
headers=DEFAULT_HEADERS,
timeout=DEFAULT_TIMEOUT,
stream=True,
allow_redirects=True,
)
response.raise_for_status()
content_length = int(response.headers.get("Content-Length", 0))
if content_length > MAX_FILE_SIZE:
return ToolResult.fail(
f"Error: File too large ({format_size(content_length)} > {format_size(MAX_FILE_SIZE)})"
)
downloaded = 0
with open(local_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
downloaded += len(chunk)
if downloaded > MAX_FILE_SIZE:
f.close()
os.remove(local_path)
return ToolResult.fail(
f"Error: File too large (>{format_size(MAX_FILE_SIZE)}), download aborted"
)
f.write(chunk)
except requests.Timeout:
return ToolResult.fail(f"Error: Download timed out after {DEFAULT_TIMEOUT}s")
except requests.ConnectionError:
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
except requests.HTTPError as e:
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
except Exception as e:
self._cleanup_file(local_path)
return ToolResult.fail(f"Error: Failed to download file: {e}")
try:
text = self._parse_document(local_path, suffix)
except Exception as e:
self._cleanup_file(local_path)
return ToolResult.fail(f"Error: Failed to parse document: {e}")
if not text or not text.strip():
file_size = os.path.getsize(local_path)
return ToolResult.success(
f"File downloaded to: {local_path} ({format_size(file_size)})\n"
f"No text content could be extracted. The file may contain only images or be encrypted."
)
truncation = truncate_head(text)
result_text = truncation.content
file_size = os.path.getsize(local_path)
header = f"[Document: {filename} | Size: {format_size(file_size)} | Saved to: {local_path}]\n\n"
if truncation.truncated:
header += f"[Content truncated: showing {truncation.output_lines} of {truncation.total_lines} lines]\n\n"
return ToolResult.success(header + result_text)
def _parse_document(self, file_path: str, suffix: str) -> str:
"""Parse document file and return extracted text."""
if suffix in PDF_SUFFIXES:
return self._parse_pdf(file_path)
elif suffix in WORD_SUFFIXES:
return self._parse_word(file_path)
elif suffix in TEXT_SUFFIXES:
return self._parse_text(file_path)
elif suffix in SPREADSHEET_SUFFIXES:
return self._parse_spreadsheet(file_path)
elif suffix in PPT_SUFFIXES:
return self._parse_ppt(file_path)
else:
return self._parse_text(file_path)
def _parse_pdf(self, file_path: str) -> str:
"""Extract text from PDF using pypdf."""
try:
from pypdf import PdfReader
except ImportError:
raise ImportError("pypdf library is required for PDF parsing. Install with: pip install pypdf")
reader = PdfReader(file_path)
text_parts = []
for page_num, page in enumerate(reader.pages, 1):
page_text = page.extract_text()
if page_text and page_text.strip():
text_parts.append(f"--- Page {page_num}/{len(reader.pages)} ---\n{page_text}")
return "\n\n".join(text_parts)
def _parse_word(self, file_path: str) -> str:
"""Extract text from Word documents (.docx)."""
try:
from docx import Document
except ImportError:
raise ImportError(
"python-docx library is required for .docx parsing. Install with: pip install python-docx"
)
doc = Document(file_path)
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
return "\n\n".join(paragraphs)
def _parse_text(self, file_path: str) -> str:
"""Read plain text files (txt, md, csv, etc.)."""
encodings = ["utf-8", "utf-8-sig", "gbk", "gb2312", "latin-1"]
for enc in encodings:
try:
with open(file_path, "r", encoding=enc) as f:
return f.read()
except (UnicodeDecodeError, UnicodeError):
continue
raise ValueError(f"Unable to decode file with any supported encoding: {encodings}")
def _parse_spreadsheet(self, file_path: str) -> str:
"""Extract text from Excel files (.xls/.xlsx)."""
try:
import openpyxl
except ImportError:
raise ImportError(
"openpyxl library is required for .xlsx parsing. Install with: pip install openpyxl"
)
wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True)
result_parts = []
for sheet_name in wb.sheetnames:
ws = wb[sheet_name]
rows = []
for row in ws.iter_rows(values_only=True):
cells = [str(c) if c is not None else "" for c in row]
if any(cells):
rows.append(" | ".join(cells))
if rows:
result_parts.append(f"--- Sheet: {sheet_name} ---\n" + "\n".join(rows))
wb.close()
return "\n\n".join(result_parts)
def _parse_ppt(self, file_path: str) -> str:
"""Extract text from PowerPoint files (.ppt/.pptx)."""
try:
from pptx import Presentation
except ImportError:
raise ImportError(
"python-pptx library is required for .pptx parsing. Install with: pip install python-pptx"
)
prs = Presentation(file_path)
text_parts = []
for slide_num, slide in enumerate(prs.slides, 1):
slide_texts = []
for shape in slide.shapes:
if shape.has_text_frame:
for paragraph in shape.text_frame.paragraphs:
text = paragraph.text.strip()
if text:
slide_texts.append(text)
if slide_texts:
text_parts.append(f"--- Slide {slide_num}/{len(prs.slides)} ---\n" + "\n".join(slide_texts))
return "\n\n".join(text_parts)
# ---- Encoding detection ----
@staticmethod
def _detect_encoding(response: requests.Response) -> str:
"""Detect response encoding with priority: Content-Type header > HTML meta > chardet > utf-8."""
# 1. Check Content-Type header for explicit charset
content_type = response.headers.get("Content-Type", "")
charset = _extract_charset_from_content_type(content_type)
if charset:
return charset
# 2. Scan raw bytes for HTML meta charset declaration
raw = response.content[:4096]
charset = _extract_charset_from_html_meta(raw)
if charset:
return charset
# 3. Use apparent_encoding (chardet-based detection) if confident enough
apparent = response.apparent_encoding
if apparent:
apparent_lower = apparent.lower()
# Trust CJK / Windows encodings detected by chardet
trusted_prefixes = ("utf", "gb", "big5", "euc", "shift_jis", "iso-2022", "windows", "ascii")
if any(apparent_lower.startswith(p) for p in trusted_prefixes):
return apparent
# 4. Fallback
return "utf-8"
# ---- Helper methods ----
def _ensure_tmp_dir(self) -> str:
"""Ensure workspace/tmp directory exists and return its path."""
tmp_dir = os.path.join(self.cwd, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
return tmp_dir
def _extract_filename(self, url: str) -> str:
"""Extract a safe filename from URL, with a short UUID prefix to avoid collisions."""
path = urlparse(url).path
basename = os.path.basename(unquote(path))
if not basename or basename == "/":
basename = "downloaded_file"
# Sanitize: keep only safe chars
basename = re.sub(r'[^\w.\-]', '_', basename)
short_id = uuid.uuid4().hex[:8]
return f"{short_id}_{basename}"
@staticmethod
def _cleanup_file(path: str):
"""Remove a file if it exists, ignoring errors."""
try:
if os.path.exists(path):
os.remove(path)
except Exception:
pass
@staticmethod
def _is_binary_content_type(content_type: str) -> bool:
"""Check if Content-Type indicates a binary/document response."""
binary_types = [
"application/pdf",
"application/vnd.openxmlformats",
"application/vnd.ms-excel",
"application/vnd.ms-powerpoint",
"application/octet-stream",
]
ct_lower = content_type.lower()
return any(bt in ct_lower for bt in binary_types)
def _handle_download_by_content_type(self, url: str, response: requests.Response, content_type: str) -> ToolResult:
"""Handle a URL that returned binary content instead of HTML."""
ct_lower = content_type.lower()
suffix_map = {
"application/pdf": ".pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml": ".docx",
"application/vnd.ms-excel": ".xls",
"application/vnd.openxmlformats-officedocument.spreadsheetml": ".xlsx",
"application/vnd.ms-powerpoint": ".ppt",
"application/vnd.openxmlformats-officedocument.presentationml": ".pptx",
}
detected_suffix = None
for ct_prefix, ext in suffix_map.items():
if ct_prefix in ct_lower:
detected_suffix = ext
break
if detected_suffix and detected_suffix in ALL_DOC_SUFFIXES:
# Re-fetch as document
return self._fetch_document(url if _get_url_suffix(url) in ALL_DOC_SUFFIXES
else self._rewrite_url_with_suffix(url, detected_suffix))
return ToolResult.fail(f"Error: URL returned binary content ({content_type}), not a supported document type")
@staticmethod
def _rewrite_url_with_suffix(url: str, suffix: str) -> str:
"""Append a suffix to the URL path so _get_url_suffix works correctly."""
parsed = urlparse(url)
new_path = parsed.path.rstrip("/") + suffix
return parsed._replace(path=new_path).geturl()
# ---- HTML extraction (unchanged) ----
@staticmethod
def _extract_title(html: str) -> str:
match = re.search(r"<title[^>]*>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
return match.group(1).strip() if match else "Untitled"
@staticmethod
def _extract_text(html: str) -> str:
text = re.sub(r"<script[^>]*>.*?</script>", "", html, flags=re.IGNORECASE | re.DOTALL)
text = re.sub(r"<style[^>]*>.*?</style>", "", text, flags=re.IGNORECASE | re.DOTALL)
text = re.sub(r"<[^>]+>", "", text)
text = text.replace("&amp;", "&").replace("&lt;", "<").replace("&gt;", ">")
text = text.replace("&quot;", '"').replace("&#39;", "'").replace("&nbsp;", " ")
text = re.sub(r"[^\S\n]+", " ", text)
text = re.sub(r"\n{3,}", "\n\n", text)
lines = [line.strip() for line in text.splitlines()]
text = "\n".join(lines)
return text.strip()

View File

@@ -0,0 +1,3 @@
from agent.tools.web_search.web_search import WebSearch
__all__ = ["WebSearch"]

View File

@@ -0,0 +1,318 @@
"""
Web Search tool - Search the web using Bocha or LinkAI search API.
Supports two backends with unified response format:
1. Bocha Search (primary, requires BOCHA_API_KEY)
2. LinkAI Search (fallback, requires LINKAI_API_KEY)
"""
import os
import json
from typing import Dict, Any, Optional
import requests
from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger
from config import conf
# Default timeout for API requests (seconds)
DEFAULT_TIMEOUT = 30
class WebSearch(BaseTool):
"""Tool for searching the web using Bocha or LinkAI search API"""
name: str = "web_search"
description: str = "Search the web for real-time information. Returns titles, URLs, and snippets."
params: dict = {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query string"
},
"count": {
"type": "integer",
"description": "Number of results to return (1-50, default: 10)"
},
"freshness": {
"type": "string",
"description": (
"Time range filter. Options: "
"'noLimit' (default), 'oneDay', 'oneWeek', 'oneMonth', 'oneYear', "
"or date range like '2025-01-01..2025-02-01'"
)
},
"summary": {
"type": "boolean",
"description": "Whether to include text summary for each result (default: false)"
}
},
"required": ["query"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self._backend = None # Will be resolved on first execute
@staticmethod
def is_available() -> bool:
"""Check if web search is available (at least one API key is configured)"""
return bool(os.environ.get("BOCHA_API_KEY") or os.environ.get("LINKAI_API_KEY"))
def _resolve_backend(self) -> Optional[str]:
"""
Determine which search backend to use.
Priority: Bocha > LinkAI
:return: 'bocha', 'linkai', or None
"""
if os.environ.get("BOCHA_API_KEY"):
return "bocha"
if os.environ.get("LINKAI_API_KEY"):
return "linkai"
return None
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute web search
:param args: Search parameters (query, count, freshness, summary)
:return: Search results
"""
query = args.get("query", "").strip()
if not query:
return ToolResult.fail("Error: 'query' parameter is required")
count = args.get("count", 10)
freshness = args.get("freshness", "noLimit")
summary = args.get("summary", False)
# Validate count
if not isinstance(count, int) or count < 1 or count > 50:
count = 10
# Resolve backend
backend = self._resolve_backend()
if not backend:
return ToolResult.fail(
"Error: No search API key configured. "
"Please set BOCHA_API_KEY or LINKAI_API_KEY using env_config tool.\n"
" - Bocha Search: https://open.bocha.cn\n"
" - LinkAI Search: https://link-ai.tech"
)
try:
if backend == "bocha":
return self._search_bocha(query, count, freshness, summary)
else:
return self._search_linkai(query, count, freshness)
except requests.Timeout:
return ToolResult.fail(f"Error: Search request timed out after {DEFAULT_TIMEOUT}s")
except requests.ConnectionError:
return ToolResult.fail("Error: Failed to connect to search API")
except Exception as e:
logger.error(f"[WebSearch] Unexpected error: {e}", exc_info=True)
return ToolResult.fail(f"Error: Search failed - {str(e)}")
def _search_bocha(self, query: str, count: int, freshness: str, summary: bool) -> ToolResult:
"""
Search using Bocha API
:param query: Search query
:param count: Number of results
:param freshness: Time range filter
:param summary: Whether to include summary
:return: Formatted search results
"""
api_key = os.environ.get("BOCHA_API_KEY", "")
url = "https://api.bocha.cn/v1/web-search"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"Accept": "application/json"
}
payload = {
"query": query,
"count": count,
"freshness": freshness,
"summary": summary
}
logger.debug(f"[WebSearch] Bocha search: query='{query}', count={count}")
response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
if response.status_code == 401:
return ToolResult.fail("Error: Invalid BOCHA_API_KEY. Please check your API key.")
if response.status_code == 403:
return ToolResult.fail("Error: Bocha API - insufficient balance. Please top up at https://open.bocha.cn")
if response.status_code == 429:
return ToolResult.fail("Error: Bocha API rate limit reached. Please try again later.")
if response.status_code != 200:
return ToolResult.fail(f"Error: Bocha API returned HTTP {response.status_code}")
data = response.json()
# Check API-level error code
api_code = data.get("code")
if api_code is not None and api_code != 200:
msg = data.get("msg") or "Unknown error"
return ToolResult.fail(f"Error: Bocha API error (code={api_code}): {msg}")
# Extract and format results
return self._format_bocha_results(data, query)
def _format_bocha_results(self, data: dict, query: str) -> ToolResult:
"""
Format Bocha API response into unified result structure
:param data: Raw API response
:param query: Original query
:return: Formatted ToolResult
"""
search_data = data.get("data", {})
web_pages = search_data.get("webPages", {})
pages = web_pages.get("value", [])
if not pages:
return ToolResult.success({
"query": query,
"backend": "bocha",
"total": 0,
"results": [],
"message": "No results found"
})
results = []
for page in pages:
result = {
"title": page.get("name", ""),
"url": page.get("url", ""),
"snippet": page.get("snippet", ""),
"siteName": page.get("siteName", ""),
"datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""),
}
# Include summary only if present
if page.get("summary"):
result["summary"] = page["summary"]
results.append(result)
total = web_pages.get("totalEstimatedMatches", len(results))
return ToolResult.success({
"query": query,
"backend": "bocha",
"total": total,
"count": len(results),
"results": results
})
def _search_linkai(self, query: str, count: int, freshness: str) -> ToolResult:
"""
Search using LinkAI plugin API
:param query: Search query
:param count: Number of results
:param freshness: Time range filter
:return: Formatted search results
"""
api_key = os.environ.get("LINKAI_API_KEY", "")
api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
url = f"{api_base.rstrip('/')}/v1/plugin/execute"
from common.utils import get_cloud_headers
headers = get_cloud_headers(api_key)
payload = {
"code": "web-search",
"args": {
"query": query,
"count": count,
"freshness": freshness
}
}
logger.debug(f"[WebSearch] LinkAI search: query='{query}', count={count}")
response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
if response.status_code == 401:
return ToolResult.fail("Error: Invalid LINKAI_API_KEY. Please check your API key.")
if response.status_code != 200:
return ToolResult.fail(f"Error: LinkAI API returned HTTP {response.status_code}")
data = response.json()
if not data.get("success"):
msg = data.get("message") or "Unknown error"
return ToolResult.fail(f"Error: LinkAI search failed: {msg}")
return self._format_linkai_results(data, query)
def _format_linkai_results(self, data: dict, query: str) -> ToolResult:
"""
Format LinkAI API response into unified result structure.
LinkAI returns the search data in data.data field, which follows
the same Bing-compatible format as Bocha.
:param data: Raw API response
:param query: Original query
:return: Formatted ToolResult
"""
raw_data = data.get("data", "")
# LinkAI may return data as a JSON string
if isinstance(raw_data, str):
try:
raw_data = json.loads(raw_data)
except (json.JSONDecodeError, TypeError):
# If data is plain text, return it as a single result
return ToolResult.success({
"query": query,
"backend": "linkai",
"total": 1,
"count": 1,
"results": [{"content": raw_data}]
})
# If the response follows Bing-compatible structure
if isinstance(raw_data, dict):
web_pages = raw_data.get("webPages", {})
pages = web_pages.get("value", [])
if pages:
results = []
for page in pages:
result = {
"title": page.get("name", ""),
"url": page.get("url", ""),
"snippet": page.get("snippet", ""),
"siteName": page.get("siteName", ""),
"datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""),
}
if page.get("summary"):
result["summary"] = page["summary"]
results.append(result)
total = web_pages.get("totalEstimatedMatches", len(results))
return ToolResult.success({
"query": query,
"backend": "linkai",
"total": total,
"count": len(results),
"results": results
})
# Fallback: return raw data
return ToolResult.success({
"query": query,
"backend": "linkai",
"total": 1,
"count": 1,
"results": [{"content": str(raw_data)}]
})

View File

@@ -8,6 +8,7 @@ from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool, ToolResult
from common.utils import expand_path
class Write(BaseTool):
@@ -90,7 +91,7 @@ class Write(BaseTool):
:return: Absolute path
"""
# Expand ~ to user home directory
path = os.path.expanduser(path)
path = expand_path(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))

307
app.py
View File

@@ -7,11 +7,260 @@ import time
from channel import channel_factory
from common import const
from config import load_config
from common.log import logger
from config import load_config, conf
from plugins import *
import threading
_channel_mgr = None
def get_channel_manager():
return _channel_mgr
def _parse_channel_type(raw) -> list:
"""
Parse channel_type config value into a list of channel names.
Supports:
- single string: "feishu"
- comma-separated string: "feishu, dingtalk"
- list: ["feishu", "dingtalk"]
"""
if isinstance(raw, list):
return [ch.strip() for ch in raw if ch.strip()]
if isinstance(raw, str):
return [ch.strip() for ch in raw.split(",") if ch.strip()]
return []
class ChannelManager:
"""
Manage the lifecycle of multiple channels running concurrently.
Each channel.startup() runs in its own daemon thread.
The web channel is started as default console unless explicitly disabled.
"""
def __init__(self):
self._channels = {} # channel_name -> channel instance
self._threads = {} # channel_name -> thread
self._primary_channel = None
self._lock = threading.Lock()
self.cloud_mode = False # set to True when cloud client is active
@property
def channel(self):
"""Return the primary (first non-web) channel for backward compatibility."""
return self._primary_channel
def get_channel(self, channel_name: str):
return self._channels.get(channel_name)
def start(self, channel_names: list, first_start: bool = False):
"""
Create and start one or more channels in sub-threads.
If first_start is True, plugins and linkai client will also be initialized.
"""
with self._lock:
channels = []
for name in channel_names:
ch = channel_factory.create_channel(name)
ch.cloud_mode = self.cloud_mode
self._channels[name] = ch
channels.append((name, ch))
if self._primary_channel is None and name != "web":
self._primary_channel = ch
if self._primary_channel is None and channels:
self._primary_channel = channels[0][1]
if first_start:
PluginManager().load_plugins()
# Cloud client is optional. It is only started when
# use_linkai=True AND cloud_deployment_id is set.
# By default neither is configured, so the app runs
# entirely locally without any remote connection.
if conf().get("use_linkai") and (
os.environ.get("CLOUD_DEPLOYMENT_ID") or conf().get("cloud_deployment_id")
):
try:
from common import cloud_client
threading.Thread(
target=cloud_client.start,
args=(self._primary_channel, self),
daemon=True,
).start()
except Exception:
pass
# Start web console first so its logs print cleanly,
# then start remaining channels after a brief pause.
web_entry = None
other_entries = []
for entry in channels:
if entry[0] == "web":
web_entry = entry
else:
other_entries.append(entry)
ordered = ([web_entry] if web_entry else []) + other_entries
for i, (name, ch) in enumerate(ordered):
if i > 0 and name != "web":
time.sleep(0.1)
t = threading.Thread(target=self._run_channel, args=(name, ch), daemon=True)
self._threads[name] = t
t.start()
logger.debug(f"[ChannelManager] Channel '{name}' started in sub-thread")
def _run_channel(self, name: str, channel):
try:
channel.startup()
except Exception as e:
logger.error(f"[ChannelManager] Channel '{name}' startup error: {e}")
logger.exception(e)
def stop(self, channel_name: str = None):
"""
Stop channel(s). If channel_name is given, stop only that channel;
otherwise stop all channels.
"""
# Pop under lock, then stop outside lock to avoid deadlock
with self._lock:
names = [channel_name] if channel_name else list(self._channels.keys())
to_stop = []
for name in names:
ch = self._channels.pop(name, None)
th = self._threads.pop(name, None)
to_stop.append((name, ch, th))
if channel_name and self._primary_channel is self._channels.get(channel_name):
self._primary_channel = None
for name, ch, th in to_stop:
if ch is None:
logger.warning(f"[ChannelManager] Channel '{name}' not found in managed channels")
if th and th.is_alive():
self._interrupt_thread(th, name)
continue
logger.info(f"[ChannelManager] Stopping channel '{name}'...")
graceful = False
if hasattr(ch, 'stop'):
try:
ch.stop()
graceful = True
except Exception as e:
logger.warning(f"[ChannelManager] Error during channel '{name}' stop: {e}")
if th and th.is_alive():
th.join(timeout=5)
if th.is_alive():
if graceful:
logger.info(f"[ChannelManager] Channel '{name}' thread still alive after stop(), "
"leaving daemon thread to finish on its own")
else:
logger.warning(f"[ChannelManager] Channel '{name}' thread did not exit in 5s, forcing interrupt")
self._interrupt_thread(th, name)
@staticmethod
def _interrupt_thread(th: threading.Thread, name: str):
"""Raise SystemExit in target thread to break blocking loops like start_forever."""
import ctypes
try:
tid = th.ident
if tid is None:
return
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
)
if res == 1:
logger.info(f"[ChannelManager] Interrupted thread for channel '{name}'")
elif res > 1:
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
logger.warning(f"[ChannelManager] Failed to interrupt thread for channel '{name}'")
except Exception as e:
logger.warning(f"[ChannelManager] Thread interrupt error for '{name}': {e}")
def restart(self, new_channel_name: str):
"""
Restart a single channel with a new channel type.
Can be called from any thread (e.g. linkai config callback).
"""
logger.info(f"[ChannelManager] Restarting channel to '{new_channel_name}'...")
self.stop(new_channel_name)
_clear_singleton_cache(new_channel_name)
time.sleep(1)
self.start([new_channel_name], first_start=False)
logger.info(f"[ChannelManager] Channel restarted to '{new_channel_name}' successfully")
def add_channel(self, channel_name: str):
"""
Dynamically add and start a new channel.
If the channel is already running, restart it instead.
"""
with self._lock:
if channel_name in self._channels:
logger.info(f"[ChannelManager] Channel '{channel_name}' already exists, restarting")
if self._channels.get(channel_name):
self.restart(channel_name)
return
logger.info(f"[ChannelManager] Adding channel '{channel_name}'...")
_clear_singleton_cache(channel_name)
self.start([channel_name], first_start=False)
logger.info(f"[ChannelManager] Channel '{channel_name}' added successfully")
def remove_channel(self, channel_name: str):
"""
Dynamically stop and remove a running channel.
"""
with self._lock:
if channel_name not in self._channels:
logger.warning(f"[ChannelManager] Channel '{channel_name}' not found, nothing to remove")
return
logger.info(f"[ChannelManager] Removing channel '{channel_name}'...")
self.stop(channel_name)
logger.info(f"[ChannelManager] Channel '{channel_name}' removed successfully")
def _clear_singleton_cache(channel_name: str):
"""
Clear the singleton cache for the channel class so that
a new instance can be created with updated config.
"""
cls_map = {
"web": "channel.web.web_channel.WebChannel",
"wechatmp": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
"wechatmp_service": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
"wechatcom_app": "channel.wechatcom.wechatcomapp_channel.WechatComAppChannel",
const.FEISHU: "channel.feishu.feishu_channel.FeiShuChanel",
const.DINGTALK: "channel.dingtalk.dingtalk_channel.DingTalkChanel",
const.WECOM_BOT: "channel.wecom_bot.wecom_bot_channel.WecomBotChannel",
const.QQ: "channel.qq.qq_channel.QQChannel",
const.WEIXIN: "channel.weixin.weixin_channel.WeixinChannel",
"wx": "channel.weixin.weixin_channel.WeixinChannel",
}
module_path = cls_map.get(channel_name)
if not module_path:
return
try:
parts = module_path.rsplit(".", 1)
module_name, class_name = parts[0], parts[1]
import importlib
module = importlib.import_module(module_name)
wrapper = getattr(module, class_name, None)
if wrapper and hasattr(wrapper, '__closure__') and wrapper.__closure__:
for cell in wrapper.__closure__:
try:
cell_contents = cell.cell_contents
if isinstance(cell_contents, dict):
cell_contents.clear()
logger.debug(f"[ChannelManager] Cleared singleton cache for {class_name}")
break
except ValueError:
pass
except Exception as e:
logger.warning(f"[ChannelManager] Failed to clear singleton cache: {e}")
def sigterm_handler_wrap(_signo):
old_handler = signal.getsignal(_signo)
@@ -25,22 +274,8 @@ def sigterm_handler_wrap(_signo):
signal.signal(_signo, func)
def start_channel(channel_name: str):
channel = channel_factory.create_channel(channel_name)
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "web", "wechatmp_service", "wechatcom_app", "wework",
const.FEISHU, const.DINGTALK]:
PluginManager().load_plugins()
if conf().get("use_linkai"):
try:
from common import linkai_client
threading.Thread(target=linkai_client.start, args=(channel,)).start()
except Exception as e:
pass
channel.startup()
def run():
global _channel_mgr
try:
# load config
load_config()
@@ -49,33 +284,25 @@ def run():
# kill signal
sigterm_handler_wrap(signal.SIGTERM)
# create channel
channel_name = conf().get("channel_type", "wx")
# Parse channel_type into a list
raw_channel = conf().get("channel_type", "web")
if "--cmd" in sys.argv:
channel_name = "terminal"
if channel_name == "wxy":
os.environ["WECHATY_LOG"] = "warn"
start_channel(channel_name)
# 打印系统运行成功信息
logger.info("")
logger.info("=" * 50)
if conf().get("agent", False):
logger.info("✅ System started successfully!")
logger.info("🐮 Cow Agent is running")
logger.info(f" Channel: {channel_name}")
logger.info(f" Model: {conf().get('model', 'unknown')}")
logger.info(f" Workspace: {conf().get('agent_workspace', '~/cow')}")
channel_names = ["terminal"]
else:
logger.info("✅ System started successfully!")
logger.info("🤖 ChatBot is running")
logger.info(f" Channel: {channel_name}")
logger.info(f" Model: {conf().get('model', 'unknown')}")
logger.info("=" * 50)
logger.info("")
channel_names = _parse_channel_type(raw_channel)
if not channel_names:
channel_names = ["web"]
# Auto-start web console unless explicitly disabled
web_console_enabled = conf().get("web_console", True)
if web_console_enabled and "web" not in channel_names:
channel_names.append("web")
logger.info(f"[App] Starting channels: {channel_names}")
_channel_mgr = ChannelManager()
_channel_mgr.start(channel_names, first_start=True)
while True:
time.sleep(1)

View File

@@ -6,12 +6,15 @@ import os
from typing import Optional, List
from agent.protocol import Agent, LLMModel, LLMRequest
from models.openai_compatible_bot import OpenAICompatibleBot
from bridge.agent_event_handler import AgentEventHandler
from bridge.agent_initializer import AgentInitializer
from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import Reply, ReplyType
from common import const
from common.log import logger
from common.utils import expand_path
from models.openai_compatible_bot import OpenAICompatibleBot
def add_openai_compatible_support(bot_instance):
@@ -20,9 +23,12 @@ def add_openai_compatible_support(bot_instance):
This allows any bot to gain tool calling capability without modifying its code,
as long as it uses OpenAI-compatible API format.
Note: Some bots like ZHIPUAIBot have native tool calling support and don't need enhancement.
"""
if hasattr(bot_instance, 'call_with_tools'):
# Bot already has tool calling support
# Bot already has tool calling support (e.g., ZHIPUAIBot)
logger.debug(f"[AgentBridge] {type(bot_instance).__name__} already has native tool calling support")
return bot_instance
# Create a temporary mixin class that combines the bot with OpenAI compatibility
@@ -59,30 +65,71 @@ class AgentLLMModel(LLMModel):
LLM Model adapter that uses COW's existing bot infrastructure
"""
_MODEL_BOT_TYPE_MAP = {
"wenxin": const.BAIDU, "wenxin-4": const.BAIDU,
"xunfei": const.XUNFEI, const.QWEN: const.QWEN,
const.MODELSCOPE: const.MODELSCOPE,
}
_MODEL_PREFIX_MAP = [
("qwen", const.QWEN_DASHSCOPE), ("qwq", const.QWEN_DASHSCOPE), ("qvq", const.QWEN_DASHSCOPE),
("gemini", const.GEMINI), ("glm", const.ZHIPU_AI), ("claude", const.CLAUDEAPI),
("moonshot", const.MOONSHOT), ("kimi", const.MOONSHOT),
("doubao", const.DOUBAO), ("deepseek", const.DEEPSEEK),
]
def __init__(self, bridge: Bridge, bot_type: str = "chat"):
# Get model name directly from config
from config import conf
model_name = conf().get("model", const.GPT_41)
super().__init__(model=model_name)
super().__init__(model=conf().get("model", const.GPT_41))
self.bridge = bridge
self.bot_type = bot_type
self._bot = None
self._use_linkai = conf().get("use_linkai", False) and conf().get("linkai_api_key")
self._bot_model = None
@property
def model(self):
from config import conf
return conf().get("model", const.GPT_41)
@model.setter
def model(self, value):
pass
def _resolve_bot_type(self, model_name: str) -> str:
"""Resolve bot type from model name, matching Bridge.__init__ logic."""
from config import conf
if conf().get("use_linkai", False) and conf().get("linkai_api_key"):
return const.LINKAI
# Support custom bot type configuration
configured_bot_type = conf().get("bot_type")
if configured_bot_type:
return configured_bot_type
if not model_name or not isinstance(model_name, str):
return const.OPENAI
if model_name in self._MODEL_BOT_TYPE_MAP:
return self._MODEL_BOT_TYPE_MAP[model_name]
if model_name.lower().startswith("minimax") or model_name in ["abab6.5-chat"]:
return const.MiniMax
if model_name in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
return const.QWEN_DASHSCOPE
if model_name in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
return const.MOONSHOT
for prefix, btype in self._MODEL_PREFIX_MAP:
if model_name.startswith(prefix):
return btype
return const.OPENAI
@property
def bot(self):
"""Lazy load the bot and enhance it with tool calling if needed"""
if self._bot is None:
# If use_linkai is enabled, use LinkAI bot directly
if self._use_linkai:
self._bot = self.bridge.find_chat_bot(const.LINKAI)
else:
self._bot = self.bridge.get_bot(self.bot_type)
# Automatically add tool calling support if not present
self._bot = add_openai_compatible_support(self._bot)
# Log bot info
bot_name = type(self._bot).__name__
"""Lazy load the bot, re-create when model changes"""
from models.bot_factory import create_bot
cur_model = self.model
if self._bot is None or self._bot_model != cur_model:
bot_type = self._resolve_bot_type(cur_model)
self._bot = create_bot(bot_type)
self._bot = add_openai_compatible_support(self._bot)
self._bot_model = cur_model
return self._bot
def call(self, request: LLMRequest):
@@ -103,12 +150,20 @@ class AgentLLMModel(LLMModel):
# Only pass max_tokens if it's explicitly set
if request.max_tokens is not None:
kwargs['max_tokens'] = request.max_tokens
# Extract system prompt if present
system_prompt = getattr(request, 'system', None)
if system_prompt:
kwargs['system'] = system_prompt
# Pass context metadata to bot
channel_type = getattr(self, 'channel_type', None)
if channel_type:
kwargs['channel_type'] = channel_type
session_id = getattr(self, 'session_id', None)
if session_id:
kwargs['session_id'] = session_id
response = self.bot.call_with_tools(**kwargs)
return self._format_response(response)
else:
@@ -127,25 +182,33 @@ class AgentLLMModel(LLMModel):
try:
if hasattr(self.bot, 'call_with_tools'):
# Use tool-enabled streaming call if available
# Ensure max_tokens is an integer, use default if None
max_tokens = request.max_tokens if request.max_tokens is not None else 4096
# Extract system prompt if present
system_prompt = getattr(request, 'system', None)
# Build kwargs for call_with_tools
kwargs = {
'messages': request.messages,
'tools': getattr(request, 'tools', None),
'stream': True,
'max_tokens': max_tokens,
'model': self.model # Pass model parameter
}
# Only pass max_tokens if explicitly set, let the bot use its default
if request.max_tokens is not None:
kwargs['max_tokens'] = request.max_tokens
# Add system prompt if present
if system_prompt:
kwargs['system'] = system_prompt
# Pass context metadata to bot
channel_type = getattr(self, 'channel_type', None)
if channel_type:
kwargs['channel_type'] = channel_type
session_id = getattr(self, 'session_id', None)
if session_id:
kwargs['session_id'] = session_id
stream = self.bot.call_with_tools(**kwargs)
# Convert stream format to our expected format
@@ -182,6 +245,9 @@ class AgentBridge:
self.default_agent = None # For backward compatibility (no session_id)
self.agent: Optional[Agent] = None
self.scheduler_initialized = False
# Create helper instances
self.initializer = AgentInitializer(bridge, self)
def create_agent(self, system_prompt: str, tools: List = None, **kwargs) -> Agent:
"""
Create the super agent with COW integration
@@ -221,11 +287,13 @@ class AgentBridge:
tools=tools,
max_steps=kwargs.get("max_steps", 15),
output_mode=kwargs.get("output_mode", "logger"),
workspace_dir=kwargs.get("workspace_dir"), # Pass workspace for skills loading
enable_skills=kwargs.get("enable_skills", True), # Enable skills by default
memory_manager=kwargs.get("memory_manager"), # Pass memory manager
workspace_dir=kwargs.get("workspace_dir"),
skill_manager=kwargs.get("skill_manager"),
enable_skills=kwargs.get("enable_skills", True),
memory_manager=kwargs.get("memory_manager"),
max_context_tokens=kwargs.get("max_context_tokens"),
context_reserve_tokens=kwargs.get("context_reserve_tokens")
context_reserve_tokens=kwargs.get("context_reserve_tokens"),
runtime_info=kwargs.get("runtime_info"),
)
# Log skill loading details
@@ -252,492 +320,19 @@ class AgentBridge:
# Check if agent exists for this session
if session_id not in self.agents:
logger.info(f"[AgentBridge] Creating new agent for session: {session_id}")
self._init_agent_for_session(session_id)
return self.agents[session_id]
def _init_default_agent(self):
"""Initialize default super agent with new prompt system"""
from config import conf
import os
# Get workspace from config
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
# Migrate API keys from config.json to environment variables (if not already set)
self._migrate_config_to_env(workspace_root)
# Load environment variables from secure .env file location
env_file = os.path.expanduser("~/.cow/.env")
if os.path.exists(env_file):
try:
from dotenv import load_dotenv
load_dotenv(env_file, override=True)
logger.info(f"[AgentBridge] Loaded environment variables from {env_file}")
except ImportError:
logger.warning("[AgentBridge] python-dotenv not installed, skipping .env file loading")
except Exception as e:
logger.warning(f"[AgentBridge] Failed to load .env file: {e}")
# Initialize workspace and create template files
from agent.prompt import ensure_workspace, load_context_files, PromptBuilder
workspace_files = ensure_workspace(workspace_root, create_templates=True)
logger.info(f"[AgentBridge] Workspace initialized at: {workspace_root}")
# Setup memory system
memory_manager = None
memory_tools = []
try:
# Try to initialize memory system
from agent.memory import MemoryManager, MemoryConfig
from agent.tools import MemorySearchTool, MemoryGetTool
# 从 config.json 读取 OpenAI 配置
openai_api_key = conf().get("open_ai_api_key", "")
openai_api_base = conf().get("open_ai_api_base", "")
# 尝试初始化 OpenAI embedding provider
embedding_provider = None
if openai_api_key:
try:
from agent.memory import create_embedding_provider
embedding_provider = create_embedding_provider(
provider="openai",
model="text-embedding-3-small",
api_key=openai_api_key,
api_base=openai_api_base or "https://api.openai.com/v1"
)
logger.info(f"[AgentBridge] OpenAI embedding initialized")
except Exception as embed_error:
logger.warning(f"[AgentBridge] OpenAI embedding failed: {embed_error}")
logger.info(f"[AgentBridge] Using keyword-only search")
else:
logger.info(f"[AgentBridge] No OpenAI API key, using keyword-only search")
# 创建 memory config
memory_config = MemoryConfig(workspace_root=workspace_root)
# 创建 memory manager
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
# 初始化时执行一次 sync确保数据库有数据
import asyncio
try:
# 尝试在当前事件循环中执行
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果事件循环正在运行,创建任务
asyncio.create_task(memory_manager.sync())
logger.info("[AgentBridge] Memory sync scheduled")
else:
# 如果没有运行的循环,直接执行
loop.run_until_complete(memory_manager.sync())
logger.info("[AgentBridge] Memory synced successfully")
except RuntimeError:
# 没有事件循环,创建新的
asyncio.run(memory_manager.sync())
logger.info("[AgentBridge] Memory synced successfully")
except Exception as e:
logger.warning(f"[AgentBridge] Memory sync failed: {e}")
# Create memory tools
memory_tools = [
MemorySearchTool(memory_manager),
MemoryGetTool(memory_manager)
]
logger.info(f"[AgentBridge] Memory system initialized")
except Exception as e:
logger.warning(f"[AgentBridge] Memory system not available: {e}")
logger.info("[AgentBridge] Continuing without memory features")
# Use ToolManager to dynamically load all available tools
from agent.tools import ToolManager
tool_manager = ToolManager()
tool_manager.load_tools()
# Create tool instances for all available tools
tools = []
file_config = {
"cwd": workspace_root,
"memory_manager": memory_manager
} if memory_manager else {"cwd": workspace_root}
for tool_name in tool_manager.tool_classes.keys():
try:
# Special handling for EnvConfig tool - pass agent_bridge reference
if tool_name == "env_config":
from agent.tools import EnvConfig
tool = EnvConfig({
"agent_bridge": self # Pass self reference for hot reload
})
else:
tool = tool_manager.create_tool(tool_name)
if tool:
# Apply workspace config to file operation tools
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls']:
tool.config = file_config
tool.cwd = file_config.get("cwd", tool.cwd if hasattr(tool, 'cwd') else None)
if 'memory_manager' in file_config:
tool.memory_manager = file_config['memory_manager']
tools.append(tool)
logger.debug(f"[AgentBridge] Loaded tool: {tool_name}")
except Exception as e:
logger.warning(f"[AgentBridge] Failed to load tool {tool_name}: {e}")
# Add memory tools
if memory_tools:
tools.extend(memory_tools)
logger.info(f"[AgentBridge] Added {len(memory_tools)} memory tools")
# Initialize scheduler service (once)
if not self.scheduler_initialized:
try:
from agent.tools.scheduler.integration import init_scheduler
if init_scheduler(self):
self.scheduler_initialized = True
logger.info("[AgentBridge] Scheduler service initialized")
except Exception as e:
logger.warning(f"[AgentBridge] Failed to initialize scheduler: {e}")
# Inject scheduler dependencies into SchedulerTool instances
if self.scheduler_initialized:
try:
from agent.tools.scheduler.integration import get_task_store, get_scheduler_service
from agent.tools import SchedulerTool
task_store = get_task_store()
scheduler_service = get_scheduler_service()
for tool in tools:
if isinstance(tool, SchedulerTool):
tool.task_store = task_store
tool.scheduler_service = scheduler_service
if not tool.config:
tool.config = {}
tool.config["channel_type"] = conf().get("channel_type", "unknown")
logger.debug("[AgentBridge] Injected scheduler dependencies into SchedulerTool")
except Exception as e:
logger.warning(f"[AgentBridge] Failed to inject scheduler dependencies: {e}")
logger.info(f"[AgentBridge] Loaded {len(tools)} tools: {[t.name for t in tools]}")
# Load context files (SOUL.md, USER.md, etc.)
context_files = load_context_files(workspace_root)
logger.info(f"[AgentBridge] Loaded {len(context_files)} context files: {[f.path for f in context_files]}")
# Check if this is the first conversation
from agent.prompt.workspace import is_first_conversation, mark_conversation_started
is_first = is_first_conversation(workspace_root)
if is_first:
logger.info("[AgentBridge] First conversation detected")
# Build system prompt using new prompt builder
prompt_builder = PromptBuilder(
workspace_dir=workspace_root,
language="zh"
)
# Get runtime info
runtime_info = {
"model": conf().get("model", "unknown"),
"workspace": workspace_root,
"channel": conf().get("channel_type", "unknown") # Get from config
}
system_prompt = prompt_builder.build(
tools=tools,
context_files=context_files,
memory_manager=memory_manager,
runtime_info=runtime_info,
is_first_conversation=is_first
)
# Mark conversation as started (will be saved after first user message)
if is_first:
mark_conversation_started(workspace_root)
logger.info("[AgentBridge] System prompt built successfully")
# Get cost control parameters from config
max_steps = conf().get("agent_max_steps", 20)
max_context_tokens = conf().get("agent_max_context_tokens", 50000)
# Create agent with configured tools and workspace
agent = self.create_agent(
system_prompt=system_prompt,
tools=tools,
max_steps=max_steps,
output_mode="logger",
workspace_dir=workspace_root, # Pass workspace to agent for skills loading
enable_skills=True, # Enable skills auto-loading
max_context_tokens=max_context_tokens
)
# Attach memory manager to agent if available
if memory_manager:
agent.memory_manager = memory_manager
logger.info(f"[AgentBridge] Memory manager attached to agent")
# Store as default agent
"""Initialize default super agent"""
agent = self.initializer.initialize_agent(session_id=None)
self.default_agent = agent
def _init_agent_for_session(self, session_id: str):
"""
Initialize agent for a specific session
Reuses the same configuration as default agent
"""
from config import conf
import os
# Get workspace from config
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
# Migrate API keys from config.json to environment variables (if not already set)
self._migrate_config_to_env(workspace_root)
# Load environment variables from secure .env file location
env_file = os.path.expanduser("~/.cow/.env")
if os.path.exists(env_file):
try:
from dotenv import load_dotenv
load_dotenv(env_file, override=True)
logger.debug(f"[AgentBridge] Loaded environment variables from {env_file} for session {session_id}")
except ImportError:
logger.warning(f"[AgentBridge] python-dotenv not installed, skipping .env file loading for session {session_id}")
except Exception as e:
logger.warning(f"[AgentBridge] Failed to load .env file for session {session_id}: {e}")
# Migrate API keys from config.json to environment variables (if not already set)
self._migrate_config_to_env(workspace_root)
# Initialize workspace
from agent.prompt import ensure_workspace, load_context_files, PromptBuilder
workspace_files = ensure_workspace(workspace_root, create_templates=True)
# Setup memory system
memory_manager = None
memory_tools = []
try:
from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider
from agent.tools import MemorySearchTool, MemoryGetTool
# 从 config.json 读取 OpenAI 配置
openai_api_key = conf().get("open_ai_api_key", "")
openai_api_base = conf().get("open_ai_api_base", "")
# 尝试初始化 OpenAI embedding provider
embedding_provider = None
if openai_api_key:
try:
embedding_provider = create_embedding_provider(
provider="openai",
model="text-embedding-3-small",
api_key=openai_api_key,
api_base=openai_api_base or "https://api.openai.com/v1"
)
logger.debug(f"[AgentBridge] OpenAI embedding initialized for session {session_id}")
except Exception as embed_error:
logger.warning(f"[AgentBridge] OpenAI embedding failed for session {session_id}: {embed_error}")
logger.info(f"[AgentBridge] Using keyword-only search for session {session_id}")
else:
logger.debug(f"[AgentBridge] No OpenAI API key, using keyword-only search for session {session_id}")
# 创建 memory config
memory_config = MemoryConfig(workspace_root=workspace_root)
# 创建 memory manager
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
# 初始化时执行一次 sync确保数据库有数据
import asyncio
try:
# 尝试在当前事件循环中执行
loop = asyncio.get_event_loop()
if loop.is_running():
# 如果事件循环正在运行,创建任务
asyncio.create_task(memory_manager.sync())
logger.debug(f"[AgentBridge] Memory sync scheduled for session {session_id}")
else:
# 如果没有运行的循环,直接执行
loop.run_until_complete(memory_manager.sync())
logger.debug(f"[AgentBridge] Memory synced successfully for session {session_id}")
except RuntimeError:
# 没有事件循环,创建新的
asyncio.run(memory_manager.sync())
logger.debug(f"[AgentBridge] Memory synced successfully for session {session_id}")
except Exception as sync_error:
logger.warning(f"[AgentBridge] Memory sync failed for session {session_id}: {sync_error}")
memory_tools = [
MemorySearchTool(memory_manager),
MemoryGetTool(memory_manager)
]
except Exception as e:
logger.warning(f"[AgentBridge] Memory system not available for session {session_id}: {e}")
import traceback
logger.warning(f"[AgentBridge] Memory init traceback: {traceback.format_exc()}")
# Load tools
from agent.tools import ToolManager
tool_manager = ToolManager()
tool_manager.load_tools()
tools = []
file_config = {
"cwd": workspace_root,
"memory_manager": memory_manager
} if memory_manager else {"cwd": workspace_root}
for tool_name in tool_manager.tool_classes.keys():
try:
tool = tool_manager.create_tool(tool_name)
if tool:
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls']:
tool.config = file_config
tool.cwd = file_config.get("cwd", tool.cwd if hasattr(tool, 'cwd') else None)
if 'memory_manager' in file_config:
tool.memory_manager = file_config['memory_manager']
tools.append(tool)
except Exception as e:
logger.warning(f"[AgentBridge] Failed to load tool {tool_name} for session {session_id}: {e}")
if memory_tools:
tools.extend(memory_tools)
# Initialize scheduler service (once, if not already initialized)
if not self.scheduler_initialized:
try:
from agent.tools.scheduler.integration import init_scheduler
if init_scheduler(self):
self.scheduler_initialized = True
logger.debug(f"[AgentBridge] Scheduler service initialized for session {session_id}")
except Exception as e:
logger.warning(f"[AgentBridge] Failed to initialize scheduler for session {session_id}: {e}")
# Inject scheduler dependencies into SchedulerTool instances
if self.scheduler_initialized:
try:
from agent.tools.scheduler.integration import get_task_store, get_scheduler_service
from agent.tools import SchedulerTool
task_store = get_task_store()
scheduler_service = get_scheduler_service()
for tool in tools:
if isinstance(tool, SchedulerTool):
tool.task_store = task_store
tool.scheduler_service = scheduler_service
if not tool.config:
tool.config = {}
tool.config["channel_type"] = conf().get("channel_type", "unknown")
logger.debug(f"[AgentBridge] Injected scheduler dependencies for session {session_id}")
except Exception as e:
logger.warning(f"[AgentBridge] Failed to inject scheduler dependencies for session {session_id}: {e}")
# Load context files
context_files = load_context_files(workspace_root)
# Initialize skill manager
skill_manager = None
try:
from agent.skills import SkillManager
skill_manager = SkillManager(workspace_dir=workspace_root)
logger.debug(f"[AgentBridge] Initialized SkillManager with {len(skill_manager.skills)} skills for session {session_id}")
except Exception as e:
logger.warning(f"[AgentBridge] Failed to initialize SkillManager for session {session_id}: {e}")
# Check if this is the first conversation
from agent.prompt.workspace import is_first_conversation, mark_conversation_started
is_first = is_first_conversation(workspace_root)
# Build system prompt
prompt_builder = PromptBuilder(
workspace_dir=workspace_root,
language="zh"
)
# Get current time and timezone info
import datetime
import time
now = datetime.datetime.now()
# Get timezone info
try:
offset = -time.timezone if not time.daylight else -time.altzone
hours = offset // 3600
minutes = (offset % 3600) // 60
if minutes:
timezone_name = f"UTC{hours:+03d}:{minutes:02d}"
else:
timezone_name = f"UTC{hours:+03d}"
except Exception:
timezone_name = "UTC"
# Chinese weekday mapping
weekday_map = {
'Monday': '星期一',
'Tuesday': '星期二',
'Wednesday': '星期三',
'Thursday': '星期四',
'Friday': '星期五',
'Saturday': '星期六',
'Sunday': '星期日'
}
weekday_zh = weekday_map.get(now.strftime("%A"), now.strftime("%A"))
runtime_info = {
"model": conf().get("model", "unknown"),
"workspace": workspace_root,
"channel": conf().get("channel_type", "unknown"),
"current_time": now.strftime("%Y-%m-%d %H:%M:%S"),
"weekday": weekday_zh,
"timezone": timezone_name
}
system_prompt = prompt_builder.build(
tools=tools,
context_files=context_files,
skill_manager=skill_manager,
memory_manager=memory_manager,
runtime_info=runtime_info,
is_first_conversation=is_first
)
if is_first:
mark_conversation_started(workspace_root)
# Get cost control parameters from config
max_steps = conf().get("agent_max_steps", 20)
max_context_tokens = conf().get("agent_max_context_tokens", 50000)
# Create agent for this session
agent = self.create_agent(
system_prompt=system_prompt,
tools=tools,
max_steps=max_steps,
output_mode="logger",
workspace_dir=workspace_root,
skill_manager=skill_manager,
enable_skills=True,
max_context_tokens=max_context_tokens
)
if memory_manager:
agent.memory_manager = memory_manager
# Store agent for this session
"""Initialize agent for a specific session"""
agent = self.initializer.initialize_agent(session_id=session_id)
self.agents[session_id] = agent
logger.info(f"[AgentBridge] Agent created for session: {session_id}")
def agent_reply(self, query: str, context: Context = None,
on_event=None, clear_history: bool = False) -> Reply:
@@ -753,9 +348,10 @@ class AgentBridge:
Returns:
Reply object
"""
session_id = None
agent = None
try:
# Extract session_id from context for user isolation
session_id = None
if context:
session_id = context.kwargs.get("session_id") or context.get("session_id")
@@ -764,6 +360,9 @@ class AgentBridge:
if not agent:
return Reply(ReplyType.ERROR, "Failed to initialize super agent")
# Create event handler for logging and channel communication
event_handler = AgentEventHandler(context=context, original_callback=on_event)
# Filter tools based on context
original_tools = agent.tools
filtered_tools = original_tools
@@ -785,17 +384,45 @@ class AgentBridge:
logger.warning(f"[AgentBridge] Failed to attach context to scheduler: {e}")
break
# Pass context metadata to model for downstream API requests
if context and hasattr(agent, 'model'):
agent.model.channel_type = context.get("channel_type", "")
agent.model.session_id = session_id or ""
# Store session_id on agent so executor can clear DB on fatal errors
agent._current_session_id = session_id
try:
# Use agent's run_stream method
# Use agent's run_stream method with event handler
response = agent.run_stream(
user_message=query,
on_event=on_event,
on_event=event_handler.handle_event,
clear_history=clear_history
)
finally:
# Restore original tools
if context and context.get("is_scheduled_task"):
agent.tools = original_tools
# Log execution summary
event_handler.log_summary()
# Persist new messages generated during this run
if session_id:
channel_type = (context.get("channel_type") or "") if context else ""
new_messages = getattr(agent, '_last_run_new_messages', [])
if new_messages:
self._persist_messages(session_id, list(new_messages), channel_type)
else:
with agent.messages_lock:
msg_count = len(agent.messages)
if msg_count == 0:
try:
from agent.memory import get_conversation_store
get_conversation_store().clear_session(session_id)
logger.info(f"[AgentBridge] Cleared DB for recovered session: {session_id}")
except Exception as e:
logger.warning(f"[AgentBridge] Failed to clear DB after recovery: {e}")
# Check if there are files to send (from read tool)
if hasattr(agent, 'stream_executor') and hasattr(agent.stream_executor, 'files_to_send'):
@@ -815,6 +442,18 @@ class AgentBridge:
except Exception as e:
logger.error(f"Agent reply error: {e}")
# If the agent cleared its messages due to format error / overflow,
# also purge the DB so the next request starts clean.
if session_id and agent:
try:
with agent.messages_lock:
msg_count = len(agent.messages)
if msg_count == 0:
from agent.memory import get_conversation_store
get_conversation_store().clear_session(session_id)
logger.info(f"[AgentBridge] Cleared DB for session after error: {session_id}")
except Exception as db_err:
logger.warning(f"[AgentBridge] Failed to clear DB after error: {db_err}")
return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")
def _create_file_reply(self, file_info: dict, text_response: str, context: Context = None) -> Reply:
@@ -843,17 +482,18 @@ class AgentBridge:
reply.text_content = text_response # Store accompanying text
return reply
# For documents (PDF, Excel, Word, PPT), use FILE type
if file_type == "document":
# For all file types (document, video, audio), use FILE type
if file_type in ["document", "video", "audio"]:
file_url = f"file://{file_path}"
logger.info(f"[AgentBridge] Sending document: {file_url}")
logger.info(f"[AgentBridge] Sending {file_type}: {file_url}")
reply = Reply(ReplyType.FILE, file_url)
reply.file_name = file_info.get("file_name", os.path.basename(file_path))
# Attach text message if present
if text_response:
reply.text_content = text_response
return reply
# For other files (video, audio), we need channel-specific handling
# For now, return text with file info
# TODO: Implement video/audio sending when channel supports it
# For other unknown file types, return text with file info
message = text_response or file_info.get("message", "文件已准备")
message += f"\n\n[文件: {file_info.get('file_name', file_path)}]"
return Reply(ReplyType.TEXT, message)
@@ -878,7 +518,7 @@ class AgentBridge:
}
# Use fixed secure location for .env file
env_file = os.path.expanduser("~/.cow/.env")
env_file = expand_path("~/.cow/.env")
# Read existing env vars from .env file
existing_env_vars = {}
@@ -931,6 +571,32 @@ class AgentBridge:
except Exception as e:
logger.warning(f"[AgentBridge] Failed to migrate API keys: {e}")
def _persist_messages(
self, session_id: str, new_messages: list, channel_type: str = ""
) -> None:
"""
Persist new messages to the conversation store after each agent run.
Failures are logged but never propagate — they must not interrupt replies.
"""
if not new_messages:
return
try:
from config import conf
if not conf().get("conversation_persistence", True):
return
except Exception:
pass
try:
from agent.memory import get_conversation_store
get_conversation_store().append_messages(
session_id, new_messages, channel_type=channel_type
)
except Exception as e:
logger.warning(
f"[AgentBridge] Failed to persist messages for session={session_id}: {e}"
)
def clear_session(self, session_id: str):
"""
Clear a specific session's agent and conversation history
@@ -950,39 +616,70 @@ class AgentBridge:
def refresh_all_skills(self) -> int:
"""
Refresh skills in all agent instances after environment variable changes.
This allows hot-reload of skills without restarting the agent.
Refresh skills and conditional tools in all agent instances after
environment variable changes. This allows hot-reload without restarting.
Returns:
Number of agent instances refreshed
"""
import os
from dotenv import load_dotenv
from config import conf
# Reload environment variables from .env file
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
env_file = os.path.join(workspace_root, '.env')
if os.path.exists(env_file):
load_dotenv(env_file, override=True)
logger.info(f"[AgentBridge] Reloaded environment variables from {env_file}")
refreshed_count = 0
# Refresh default agent
if self.default_agent and hasattr(self.default_agent, 'skill_manager'):
self.default_agent.skill_manager.refresh_skills()
refreshed_count += 1
logger.info("[AgentBridge] Refreshed skills in default agent")
# Refresh all session agents
# Collect all agent instances to refresh
agents_to_refresh = []
if self.default_agent:
agents_to_refresh.append(("default", self.default_agent))
for session_id, agent in self.agents.items():
if hasattr(agent, 'skill_manager'):
agents_to_refresh.append((session_id, agent))
for label, agent in agents_to_refresh:
# Refresh skills
if hasattr(agent, 'skill_manager') and agent.skill_manager:
agent.skill_manager.refresh_skills()
refreshed_count += 1
# Refresh conditional tools (e.g. web_search depends on API keys)
self._refresh_conditional_tools(agent)
refreshed_count += 1
if refreshed_count > 0:
logger.info(f"[AgentBridge] Refreshed skills in {refreshed_count} agent instance(s)")
return refreshed_count
logger.info(f"[AgentBridge] Refreshed skills & tools in {refreshed_count} agent instance(s)")
return refreshed_count
@staticmethod
def _refresh_conditional_tools(agent):
"""
Add or remove conditional tools based on current environment variables.
For example, web_search should only be present when BOCHA_API_KEY or
LINKAI_API_KEY is set.
"""
try:
from agent.tools.web_search.web_search import WebSearch
has_tool = any(t.name == "web_search" for t in agent.tools)
available = WebSearch.is_available()
if available and not has_tool:
# API key was added - inject the tool
tool = WebSearch()
tool.model = agent.model
agent.tools.append(tool)
logger.info("[AgentBridge] web_search tool added (API key now available)")
elif not available and has_tool:
# API key was removed - remove the tool
agent.tools = [t for t in agent.tools if t.name != "web_search"]
logger.info("[AgentBridge] web_search tool removed (API key no longer available)")
except Exception as e:
logger.debug(f"[AgentBridge] Failed to refresh conditional tools: {e}")

View File

@@ -0,0 +1,115 @@
"""
Agent Event Handler - Handles agent events and thinking process output
"""
from common.log import logger
class AgentEventHandler:
"""
Handles agent events and optionally sends intermediate messages to channel
"""
def __init__(self, context=None, original_callback=None):
"""
Initialize event handler
Args:
context: COW context (for accessing channel)
original_callback: Original event callback to chain
"""
self.context = context
self.original_callback = original_callback
# Get channel for sending intermediate messages
self.channel = None
if context:
self.channel = context.kwargs.get("channel") if hasattr(context, "kwargs") else None
# Track current thinking for channel output
self.current_thinking = ""
self.turn_number = 0
def handle_event(self, event):
"""
Main event handler
Args:
event: Event dict with type and data
"""
event_type = event.get("type")
data = event.get("data", {})
# Dispatch to specific handlers
if event_type == "turn_start":
self._handle_turn_start(data)
elif event_type == "message_update":
self._handle_message_update(data)
elif event_type == "message_end":
self._handle_message_end(data)
elif event_type == "tool_execution_start":
self._handle_tool_execution_start(data)
elif event_type == "tool_execution_end":
self._handle_tool_execution_end(data)
# Call original callback if provided
if self.original_callback:
self.original_callback(event)
def _handle_turn_start(self, data):
"""Handle turn start event"""
self.turn_number = data.get("turn", 0)
self.has_tool_calls_in_turn = False
self.current_thinking = ""
def _handle_message_update(self, data):
"""Handle message update event (streaming text)"""
delta = data.get("delta", "")
self.current_thinking += delta
def _handle_message_end(self, data):
"""Handle message end event"""
tool_calls = data.get("tool_calls", [])
# Only send thinking process if followed by tool calls
if tool_calls:
if self.current_thinking.strip():
logger.info(f"💭 {self.current_thinking.strip()[:200]}{'...' if len(self.current_thinking) > 200 else ''}")
# Send thinking process to channel
self._send_to_channel(f"{self.current_thinking.strip()}")
else:
# No tool calls = final response (logged at agent_stream level)
if self.current_thinking.strip():
logger.debug(f"💬 {self.current_thinking.strip()[:200]}{'...' if len(self.current_thinking) > 200 else ''}")
self.current_thinking = ""
def _handle_tool_execution_start(self, data):
"""Handle tool execution start event - logged by agent_stream.py"""
pass
def _handle_tool_execution_end(self, data):
"""Handle tool execution end event - logged by agent_stream.py"""
pass
def _send_to_channel(self, message):
"""
Try to send intermediate message to channel.
Skipped in SSE mode because thinking text is already streamed via on_event.
"""
if self.context and self.context.get("on_event"):
return
if self.channel:
try:
from bridge.reply import Reply, ReplyType
reply = Reply(ReplyType.TEXT, message)
self.channel._send(reply, self.context)
except Exception as e:
logger.debug(f"[AgentEventHandler] Failed to send to channel: {e}")
def log_summary(self):
"""Log execution summary - simplified"""
# Summary removed as per user request
# Real-time logging during execution is sufficient
pass

584
bridge/agent_initializer.py Normal file
View File

@@ -0,0 +1,584 @@
"""
Agent Initializer - Handles agent initialization logic
"""
import os
import asyncio
import datetime
import time
from typing import Optional, List
from agent.protocol import Agent
from agent.tools import ToolManager
from common.log import logger
from common.utils import expand_path
class AgentInitializer:
"""
Handles agent initialization including:
- Workspace setup
- Memory system initialization
- Tool loading
- System prompt building
"""
def __init__(self, bridge, agent_bridge):
"""
Initialize agent initializer
Args:
bridge: COW bridge instance
agent_bridge: AgentBridge instance (for create_agent method)
"""
self.bridge = bridge
self.agent_bridge = agent_bridge
def initialize_agent(self, session_id: Optional[str] = None) -> Agent:
"""
Initialize agent for a session
Args:
session_id: Session ID (None for default agent)
Returns:
Initialized agent instance
"""
from config import conf
# Get workspace from config
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
# Migrate API keys
self._migrate_config_to_env(workspace_root)
# Load environment variables
self._load_env_file()
# Initialize workspace
from agent.prompt import ensure_workspace, load_context_files, PromptBuilder
workspace_files = ensure_workspace(workspace_root, create_templates=True)
if session_id is None:
logger.info(f"[AgentInitializer] Workspace initialized at: {workspace_root}")
# Setup memory system
memory_manager, memory_tools = self._setup_memory_system(workspace_root, session_id)
# Load tools
tools = self._load_tools(workspace_root, memory_manager, memory_tools, session_id)
# Initialize scheduler if needed
self._initialize_scheduler(tools, session_id)
# Load context files
context_files = load_context_files(workspace_root)
# Initialize skill manager
skill_manager = self._initialize_skill_manager(workspace_root, session_id)
# Build system prompt
prompt_builder = PromptBuilder(workspace_dir=workspace_root, language="zh")
runtime_info = self._get_runtime_info(workspace_root)
system_prompt = prompt_builder.build(
tools=tools,
context_files=context_files,
skill_manager=skill_manager,
memory_manager=memory_manager,
runtime_info=runtime_info,
)
# Get cost control parameters
from config import conf
max_steps = conf().get("agent_max_steps", 20)
max_context_tokens = conf().get("agent_max_context_tokens", 50000)
# Create agent
agent = self.agent_bridge.create_agent(
system_prompt=system_prompt,
tools=tools,
max_steps=max_steps,
output_mode="logger",
workspace_dir=workspace_root,
skill_manager=skill_manager,
enable_skills=True,
max_context_tokens=max_context_tokens,
runtime_info=runtime_info # Pass runtime_info for dynamic time updates
)
# Attach memory manager and share LLM model for summarization
if memory_manager:
agent.memory_manager = memory_manager
if hasattr(agent, 'model') and agent.model:
memory_manager.flush_manager.llm_model = agent.model
# Restore persisted conversation history for this session
if session_id:
self._restore_conversation_history(agent, session_id)
# Start daily memory flush timer (once, on first agent init regardless of session)
self._start_daily_flush_timer()
return agent
def _restore_conversation_history(self, agent, session_id: str) -> None:
"""
Load persisted conversation messages from SQLite and inject them
into the agent's in-memory message list.
Only user text and assistant text are restored. Tool call chains
(tool_use / tool_result) are stripped out because:
1. They are intermediate process, the value is already in the final
assistant text reply.
2. They consume massive context tokens (often 80%+ of history).
3. Different models have incompatible tool message formats, so
restoring tool chains across model switches causes 400 errors.
4. Eliminates the entire class of tool_use/tool_result pairing bugs.
"""
from config import conf
if not conf().get("conversation_persistence", True):
return
try:
from agent.memory import get_conversation_store
store = get_conversation_store()
max_turns = conf().get("agent_max_context_turns", 20)
restore_turns = max(3, max_turns // 6)
saved = store.load_messages(session_id, max_turns=restore_turns)
if saved:
filtered = self._filter_text_only_messages(saved)
if filtered:
with agent.messages_lock:
agent.messages = filtered
logger.debug(
f"[AgentInitializer] Restored {len(filtered)} text messages "
f"(from {len(saved)} total, {restore_turns} turns cap) "
f"for session={session_id}"
)
except Exception as e:
logger.warning(
f"[AgentInitializer] Failed to restore conversation history for "
f"session={session_id}: {e}"
)
@staticmethod
def _filter_text_only_messages(messages: list) -> list:
"""
Extract clean user/assistant turn pairs from raw message history.
Groups messages into turns (each starting with a real user query),
then keeps only:
- The first user text in each turn (the actual user input)
- The last assistant text in each turn (the final answer)
All tool_use, tool_result, intermediate assistant thoughts, and
internal hint messages injected by the agent loop are discarded.
"""
def _extract_text(content) -> str:
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
parts = [
b.get("text", "")
for b in content
if isinstance(b, dict) and b.get("type") == "text"
]
return "\n".join(p for p in parts if p).strip()
return ""
def _is_real_user_msg(msg: dict) -> bool:
"""True for actual user input, False for tool_result or internal hints."""
if msg.get("role") != "user":
return False
content = msg.get("content")
if isinstance(content, list):
has_tool_result = any(
isinstance(b, dict) and b.get("type") == "tool_result"
for b in content
)
if has_tool_result:
return False
text = _extract_text(content)
return bool(text)
# Group into turns: each turn starts with a real user message
turns = []
current_turn = None
for msg in messages:
if _is_real_user_msg(msg):
if current_turn is not None:
turns.append(current_turn)
current_turn = {"user": msg, "assistants": []}
elif current_turn is not None and msg.get("role") == "assistant":
text = _extract_text(msg.get("content"))
if text:
current_turn["assistants"].append(text)
if current_turn is not None:
turns.append(current_turn)
# Build result: one user msg + one assistant msg per turn
filtered = []
for turn in turns:
user_text = _extract_text(turn["user"].get("content"))
if not user_text:
continue
filtered.append({
"role": "user",
"content": [{"type": "text", "text": user_text}]
})
if turn["assistants"]:
final_reply = turn["assistants"][-1]
filtered.append({
"role": "assistant",
"content": [{"type": "text", "text": final_reply}]
})
return filtered
def _load_env_file(self):
"""Load environment variables from .env file"""
env_file = expand_path("~/.cow/.env")
if os.path.exists(env_file):
try:
from dotenv import load_dotenv
load_dotenv(env_file, override=True)
except ImportError:
logger.warning("[AgentInitializer] python-dotenv not installed")
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to load .env file: {e}")
def _setup_memory_system(self, workspace_root: str, session_id: Optional[str] = None):
"""
Setup memory system
Returns:
(memory_manager, memory_tools) tuple
"""
memory_manager = None
memory_tools = []
try:
from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider
from agent.tools import MemorySearchTool, MemoryGetTool
from config import conf
# Initialize embedding provider (prefer OpenAI, fallback to LinkAI)
embedding_provider = None
openai_api_key = conf().get("open_ai_api_key", "")
openai_api_base = conf().get("open_ai_api_base", "")
if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
try:
embedding_provider = create_embedding_provider(
provider="openai",
model="text-embedding-3-small",
api_key=openai_api_key,
api_base=openai_api_base or "https://api.openai.com/v1"
)
if session_id is None:
logger.info("[AgentInitializer] OpenAI embedding initialized")
except Exception as e:
logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}")
if embedding_provider is None:
linkai_api_key = conf().get("linkai_api_key", "") or os.environ.get("LINKAI_API_KEY", "")
linkai_api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
if linkai_api_key and linkai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
try:
embedding_provider = create_embedding_provider(
provider="linkai",
model="text-embedding-3-small",
api_key=linkai_api_key,
api_base=f"{linkai_api_base}/v1"
)
if session_id is None:
logger.info("[AgentInitializer] LinkAI embedding initialized (fallback)")
except Exception as e:
logger.warning(f"[AgentInitializer] LinkAI embedding failed: {e}")
# Create memory manager
memory_config = MemoryConfig(workspace_root=workspace_root)
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
# Sync memory
self._sync_memory(memory_manager, session_id)
# Create memory tools
memory_tools = [
MemorySearchTool(memory_manager),
MemoryGetTool(memory_manager)
]
if session_id is None:
logger.info("[AgentInitializer] Memory system initialized")
except Exception as e:
logger.warning(f"[AgentInitializer] Memory system not available: {e}")
return memory_manager, memory_tools
def _sync_memory(self, memory_manager, session_id: Optional[str] = None):
"""Sync memory database"""
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
raise RuntimeError("Event loop is closed")
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
if loop.is_running():
asyncio.create_task(memory_manager.sync())
else:
loop.run_until_complete(memory_manager.sync())
except Exception as e:
logger.warning(f"[AgentInitializer] Memory sync failed: {e}")
def _load_tools(self, workspace_root: str, memory_manager, memory_tools: List, session_id: Optional[str] = None):
"""Load all tools"""
tool_manager = ToolManager()
tool_manager.load_tools()
tools = []
file_config = {
"cwd": workspace_root,
"memory_manager": memory_manager
} if memory_manager else {"cwd": workspace_root}
for tool_name in tool_manager.tool_classes.keys():
try:
# Skip web_search if no API key is available
if tool_name == "web_search":
from agent.tools.web_search.web_search import WebSearch
if not WebSearch.is_available():
logger.debug("[AgentInitializer] WebSearch skipped - no BOCHA_API_KEY or LINKAI_API_KEY")
continue
# Special handling for EnvConfig tool
if tool_name == "env_config":
from agent.tools import EnvConfig
tool = EnvConfig({"agent_bridge": self.agent_bridge})
else:
tool = tool_manager.create_tool(tool_name)
if tool:
# Apply workspace config to file operation tools
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls', 'web_fetch']:
tool.config = file_config
tool.cwd = file_config.get("cwd", getattr(tool, 'cwd', None))
if 'memory_manager' in file_config:
tool.memory_manager = file_config['memory_manager']
tools.append(tool)
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to load tool {tool_name}: {e}")
# Add memory tools
if memory_tools:
tools.extend(memory_tools)
if session_id is None:
logger.info(f"[AgentInitializer] Added {len(memory_tools)} memory tools")
if session_id is None:
logger.info(f"[AgentInitializer] Loaded {len(tools)} tools: {[t.name for t in tools]}")
return tools
def _initialize_scheduler(self, tools: List, session_id: Optional[str] = None):
"""Initialize scheduler service if needed"""
if not self.agent_bridge.scheduler_initialized:
try:
from agent.tools.scheduler.integration import init_scheduler
if init_scheduler(self.agent_bridge):
self.agent_bridge.scheduler_initialized = True
if session_id is None:
logger.info("[AgentInitializer] Scheduler service initialized")
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to initialize scheduler: {e}")
# Inject scheduler dependencies
if self.agent_bridge.scheduler_initialized:
try:
from agent.tools.scheduler.integration import get_task_store, get_scheduler_service
from agent.tools import SchedulerTool
from config import conf
task_store = get_task_store()
scheduler_service = get_scheduler_service()
for tool in tools:
if isinstance(tool, SchedulerTool):
tool.task_store = task_store
tool.scheduler_service = scheduler_service
if not tool.config:
tool.config = {}
raw_ct = conf().get("channel_type", "unknown")
if isinstance(raw_ct, list):
ct = raw_ct[0] if raw_ct else "unknown"
elif isinstance(raw_ct, str) and "," in raw_ct:
ct = raw_ct.split(",")[0].strip()
else:
ct = raw_ct
tool.config["channel_type"] = ct
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to inject scheduler dependencies: {e}")
def _initialize_skill_manager(self, workspace_root: str, session_id: Optional[str] = None):
"""Initialize skill manager"""
try:
from agent.skills import SkillManager
skill_manager = SkillManager(custom_dir=os.path.join(workspace_root, "skills"))
return skill_manager
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to initialize SkillManager: {e}")
return None
def _get_runtime_info(self, workspace_root: str):
"""Get runtime information with dynamic time support"""
from config import conf
def get_current_time():
"""Get current time dynamically - called each time system prompt is accessed"""
now = datetime.datetime.now()
# Get timezone info
try:
offset = -time.timezone if not time.daylight else -time.altzone
hours = offset // 3600
minutes = (offset % 3600) // 60
timezone_name = f"UTC{hours:+03d}:{minutes:02d}" if minutes else f"UTC{hours:+03d}"
except Exception:
timezone_name = "UTC"
# Chinese weekday mapping
weekday_map = {
'Monday': '星期一', 'Tuesday': '星期二', 'Wednesday': '星期三',
'Thursday': '星期四', 'Friday': '星期五', 'Saturday': '星期六', 'Sunday': '星期日'
}
weekday_zh = weekday_map.get(now.strftime("%A"), now.strftime("%A"))
return {
'time': now.strftime("%Y-%m-%d %H:%M:%S"),
'weekday': weekday_zh,
'timezone': timezone_name
}
return {
"model": conf().get("model", "unknown"),
"workspace": workspace_root,
"channel": ", ".join(conf().get("channel_type")) if isinstance(conf().get("channel_type"), list) else conf().get("channel_type", "unknown"),
"_get_current_time": get_current_time # Dynamic time function
}
def _migrate_config_to_env(self, workspace_root: str):
"""Migrate API keys from config.json to .env file"""
from config import conf
key_mapping = {
"open_ai_api_key": "OPENAI_API_KEY",
"open_ai_api_base": "OPENAI_API_BASE",
"gemini_api_key": "GEMINI_API_KEY",
"claude_api_key": "CLAUDE_API_KEY",
"linkai_api_key": "LINKAI_API_KEY",
}
env_file = expand_path("~/.cow/.env")
# Read existing env vars
existing_env_vars = {}
if os.path.exists(env_file):
try:
with open(env_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line and not line.startswith('#') and '=' in line:
key, _ = line.split('=', 1)
existing_env_vars[key.strip()] = True
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to read .env file: {e}")
# Check which keys need migration
keys_to_migrate = {}
for config_key, env_key in key_mapping.items():
if env_key in existing_env_vars:
continue
value = conf().get(config_key, "")
if value and value.strip():
keys_to_migrate[env_key] = value.strip()
# Write new keys
if keys_to_migrate:
try:
env_dir = os.path.dirname(env_file)
if not os.path.exists(env_dir):
os.makedirs(env_dir, exist_ok=True)
if not os.path.exists(env_file):
open(env_file, 'a').close()
with open(env_file, 'a', encoding='utf-8') as f:
f.write('\n# Auto-migrated from config.json\n')
for key, value in keys_to_migrate.items():
f.write(f'{key}={value}\n')
os.environ[key] = value
logger.info(f"[AgentInitializer] Migrated {len(keys_to_migrate)} API keys to .env: {list(keys_to_migrate.keys())}")
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to migrate API keys: {e}")
def _start_daily_flush_timer(self):
"""Start a background thread that flushes all agents' memory daily at 23:55."""
if getattr(self.agent_bridge, '_daily_flush_started', False):
return
self.agent_bridge._daily_flush_started = True
import threading
def _daily_flush_loop():
while True:
try:
now = datetime.datetime.now()
target = now.replace(hour=23, minute=55, second=0, microsecond=0)
if target <= now:
target += datetime.timedelta(days=1)
wait_seconds = (target - now).total_seconds()
logger.info(f"[DailyFlush] Next flush at {target.strftime('%Y-%m-%d %H:%M')} (in {wait_seconds/3600:.1f}h)")
time.sleep(wait_seconds)
self._flush_all_agents()
except Exception as e:
logger.warning(f"[DailyFlush] Error in daily flush loop: {e}")
time.sleep(3600)
t = threading.Thread(target=_daily_flush_loop, daemon=True)
t.start()
def _flush_all_agents(self):
"""Flush memory for all active agent sessions."""
agents = []
if self.agent_bridge.default_agent:
agents.append(("default", self.agent_bridge.default_agent))
for sid, agent in self.agent_bridge.agents.items():
agents.append((sid, agent))
if not agents:
return
flushed = 0
for label, agent in agents:
try:
if not agent.memory_manager:
continue
with agent.messages_lock:
messages = list(agent.messages)
if not messages:
continue
result = agent.memory_manager.flush_manager.create_daily_summary(messages)
if result:
flushed += 1
except Exception as e:
logger.warning(f"[DailyFlush] Failed for session {label}: {e}")
if flushed:
logger.info(f"[DailyFlush] Flushed {flushed}/{len(agents)} agent session(s)")

View File

@@ -13,7 +13,7 @@ from voice.factory import create_voice
class Bridge(object):
def __init__(self):
self.btype = {
"chat": const.CHATGPT,
"chat": const.OPENAI,
"voice_to_text": conf().get("voice_to_text", "openai"),
"text_to_voice": conf().get("text_to_voice", "google"),
"translate": conf().get("translate", "baidu"),
@@ -24,6 +24,13 @@ class Bridge(object):
self.btype["chat"] = bot_type
else:
model_type = conf().get("model") or const.GPT_41_MINI
# Ensure model_type is string to prevent AttributeError when using startswith()
# This handles cases where numeric model names (e.g., "1") are parsed as integers from YAML
if not isinstance(model_type, str):
logger.warning(f"[Bridge] model_type is not a string: {model_type} (type: {type(model_type).__name__}), converting to string")
model_type = str(model_type)
if model_type in ["text-davinci-003"]:
self.btype["chat"] = const.OPEN_AI
if conf().get("use_azure_chatgpt", False):
@@ -36,6 +43,9 @@ class Bridge(object):
self.btype["chat"] = const.QWEN
if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
self.btype["chat"] = const.QWEN_DASHSCOPE
# Support Qwen3 and other DashScope models
if model_type and (model_type.startswith("qwen") or model_type.startswith("qwq") or model_type.startswith("qvq")):
self.btype["chat"] = const.QWEN_DASHSCOPE
if model_type and model_type.startswith("gemini"):
self.btype["chat"] = const.GEMINI
if model_type and model_type.startswith("glm"):
@@ -43,16 +53,22 @@ class Bridge(object):
if model_type and model_type.startswith("claude"):
self.btype["chat"] = const.CLAUDEAPI
if model_type in ["claude"]:
self.btype["chat"] = const.CLAUDEAI
if model_type in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
self.btype["chat"] = const.MOONSHOT
if model_type and model_type.startswith("kimi"):
self.btype["chat"] = const.MOONSHOT
if model_type and model_type.startswith("doubao"):
self.btype["chat"] = const.DOUBAO
if model_type and model_type.startswith("deepseek"):
self.btype["chat"] = const.DEEPSEEK
if model_type in [const.MODELSCOPE]:
self.btype["chat"] = const.MODELSCOPE
if model_type in ["abab6.5-chat"]:
# MiniMax models
if model_type and (model_type in ["abab6.5-chat", "abab6.5"] or model_type.lower().startswith("minimax")):
self.btype["chat"] = const.MiniMax
if conf().get("use_linkai") and conf().get("linkai_api_key"):

View File

@@ -13,12 +13,44 @@ class Channel(object):
channel_type = ""
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
def __init__(self):
import threading
self._startup_event = threading.Event()
self._startup_error = None
self.cloud_mode = False # set to True by ChannelManager when running with cloud client
def startup(self):
"""
init channel
"""
raise NotImplementedError
def report_startup_success(self):
self._startup_error = None
self._startup_event.set()
def report_startup_error(self, error: str):
self._startup_error = error
self._startup_event.set()
def wait_startup(self, timeout: float = 3) -> (bool, str):
"""
Wait for channel startup result.
Returns (success: bool, error_msg: str).
"""
ready = self._startup_event.wait(timeout=timeout)
if not ready:
return True, ""
if self._startup_error:
return False, self._startup_error
return True, ""
def stop(self):
"""
stop channel gracefully, called before restart
"""
pass
def handle_text(self, msg):
"""
process received msg
@@ -51,11 +83,14 @@ class Channel(object):
if context and "channel_type" not in context:
context["channel_type"] = self.channel_type
# Read on_event callback injected by the channel (e.g. web SSE)
on_event = context.get("on_event") if context else None
# Use agent bridge to handle the query
return Bridge().fetch_agent_reply(
query=query,
context=context,
on_event=None,
on_event=on_event,
clear_history=False
)
except Exception as e:

View File

@@ -12,16 +12,7 @@ def create_channel(channel_type) -> Channel:
:return: channel instance
"""
ch = Channel()
if channel_type == "wx":
from channel.wechat.wechat_channel import WechatChannel
ch = WechatChannel()
elif channel_type == "wxy":
from channel.wechat.wechaty_channel import WechatyChannel
ch = WechatyChannel()
elif channel_type == "wcf":
from channel.wechat.wcf_channel import WechatfChannel
ch = WechatfChannel()
elif channel_type == "terminal":
if channel_type == "terminal":
from channel.terminal.terminal_channel import TerminalChannel
ch = TerminalChannel()
elif channel_type == 'web':
@@ -36,15 +27,22 @@ def create_channel(channel_type) -> Channel:
elif channel_type == "wechatcom_app":
from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
ch = WechatComAppChannel()
elif channel_type == "wework":
from channel.wework.wework_channel import WeworkChannel
ch = WeworkChannel()
elif channel_type == const.FEISHU:
from channel.feishu.feishu_channel import FeiShuChanel
ch = FeiShuChanel()
elif channel_type == const.DINGTALK:
from channel.dingtalk.dingtalk_channel import DingTalkChanel
ch = DingTalkChanel()
elif channel_type == const.WECOM_BOT:
from channel.wecom_bot.wecom_bot_channel import WecomBotChannel
ch = WecomBotChannel()
elif channel_type == const.QQ:
from channel.qq.qq_channel import QQChannel
ch = QQChannel()
elif channel_type in (const.WEIXIN, "wx"):
from channel.weixin.weixin_channel import WeixinChannel
ch = WeixinChannel()
channel_type = const.WEIXIN
else:
raise RuntimeError
ch.channel_type = channel_type

View File

@@ -24,11 +24,17 @@ handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
class ChatChannel(Channel):
name = None # 登录的用户名
user_id = None # 登录的用户id
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉正在执行的不会被取消
sessions = {} # 用于控制并发每个session_id同时只能有一个context在处理
lock = threading.Lock() # 用于控制对sessions的访问
def __init__(self):
super().__init__()
# Instance-level attributes so each channel subclass has its own
# independent session queue and lock. Previously these were class-level,
# which caused contexts from one channel (e.g. Feishu) to be consumed
# by another channel's consume() thread (e.g. Web), leading to errors
# like "No request_id found in context".
self.futures = {}
self.sessions = {}
self.lock = threading.Lock()
_thread = threading.Thread(target=self.consume)
_thread.setDaemon(True)
_thread.start()
@@ -37,9 +43,8 @@ class ChatChannel(Channel):
def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content)
context.kwargs = kwargs
# context首次传入时origin_ctype是None,
# 引入的起因是当输入语音时会嵌套生成两个context第一步语音转文本第二步通过文本生成文字回复。
# origin_ctype用于第二步文本回复时判断是否需要匹配前缀如果是私聊的语音就不需要匹配前缀
if "channel_type" not in context:
context["channel_type"] = self.channel_type
if "origin_ctype" not in context:
context["origin_ctype"] = ctype
# context首次传入时receiver是None根据类型设置receiver
@@ -426,7 +431,7 @@ class ChatChannel(Channel):
if session_id not in self.sessions:
self.sessions[session_id] = [
Dequeue(),
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
threading.BoundedSemaphore(conf().get("concurrency_in_session", 1)),
]
if context.type == ContextType.TEXT and context.content.startswith("#"):
self.sessions[session_id][0].putleft(context) # 优先处理管理命令

View File

@@ -1,5 +1,5 @@
"""
本类表示聊天消息用于对itchat和wechaty的消息进行统一的封装。
Unified chat message class for different channel implementations.
填好必填项(群聊6个非群聊8个)即可接入ChatChannel并支持插件参考TerminalChannel

View File

@@ -21,6 +21,7 @@ from dingtalk_stream.card_replier import CardReplier
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel
from common.utils import expand_path
from channel.dingtalk.dingtalk_message import DingTalkMessage
from common.expired_dict import ExpiredDict
from common.log import logger
@@ -89,13 +90,9 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
dingtalk_client_secret = conf().get('dingtalk_client_secret')
def setup_logger(self):
logger = logging.getLogger()
handler = logging.StreamHandler()
handler.setFormatter(
logging.Formatter('%(asctime)s %(name)-8s %(levelname)-8s %(message)s [%(filename)s:%(lineno)d]'))
logger.addHandler(handler)
logger.setLevel(logging.INFO)
return logger
# Suppress verbose logs from dingtalk_stream SDK
logging.getLogger("dingtalk_stream").setLevel(logging.WARNING)
return logging.getLogger("DingTalk")
def __init__(self):
super().__init__()
@@ -103,6 +100,9 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
self.logger = self.setup_logger()
# 历史消息id暂存用于幂等控制
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
self._stream_client = None
self._running = False
self._event_loop = None
logger.debug("[DingTalk] client_id={}, client_secret={} ".format(
self.dingtalk_client_id, self.dingtalk_client_secret))
# 无需群校验和前缀
@@ -115,12 +115,130 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
# Robot code cache (extracted from incoming messages)
self._robot_code = None
def _open_connection(self, client):
"""
Open a DingTalk stream connection directly, bypassing SDK's internal error-swallowing.
Returns (connection_dict, error_str). On success error_str is empty; on failure
connection_dict is None and error_str contains a human-readable message.
"""
try:
resp = requests.post(
"https://api.dingtalk.com/v1.0/gateway/connections/open",
headers={"Content-Type": "application/json", "Accept": "application/json"},
json={
"clientId": client.credential.client_id,
"clientSecret": client.credential.client_secret,
"subscriptions": [{"type": "CALLBACK",
"topic": dingtalk_stream.chatbot.ChatbotMessage.TOPIC}],
"ua": "dingtalk-sdk-python/cow",
"localIp": "",
},
timeout=10,
)
body = resp.json()
if not resp.ok:
code = body.get("code", resp.status_code)
message = body.get("message", resp.reason)
return None, f"open connection failed: [{code}] {message}"
return body, ""
except Exception as e:
return None, f"open connection failed: {e}"
def startup(self):
import asyncio
self.dingtalk_client_id = conf().get('dingtalk_client_id')
self.dingtalk_client_secret = conf().get('dingtalk_client_secret')
self._running = True
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
client = dingtalk_stream.DingTalkStreamClient(credential)
self._stream_client = client
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
logger.info("[DingTalk] ✅ Stream connected, ready to receive messages")
client.start_forever()
logger.info("[DingTalk] ✅ Stream client initialized, ready to receive messages")
# Run the connection loop ourselves instead of delegating to client.start(),
# so we can get detailed error messages and respond to stop() quickly.
import urllib.parse as _urlparse
import websockets as _ws
import json as _json
client.pre_start()
_first_connect = True
while self._running:
# Open connection using our own request so we get detailed error info.
connection, err_msg = self._open_connection(client)
if connection is None:
if _first_connect:
logger.warning(f"[DingTalk] {err_msg}")
self.report_startup_error(err_msg)
_first_connect = False
else:
logger.warning(f"[DingTalk] {err_msg}, retrying in 10s...")
# Interruptible sleep: checks _running every 100ms.
for _ in range(100):
if not self._running:
break
time.sleep(0.1)
continue
if _first_connect:
logger.info("[DingTalk] ✅ Connected to DingTalk stream")
self.report_startup_success()
_first_connect = False
else:
logger.info("[DingTalk] Reconnected to DingTalk stream")
# Run the WebSocket session in an asyncio loop.
uri = '%s?ticket=%s' % (
connection['endpoint'],
_urlparse.quote_plus(connection['ticket'])
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self._event_loop = loop
try:
async def _session():
async with _ws.connect(uri) as websocket:
client.websocket = websocket
async for raw_message in websocket:
json_message = _json.loads(raw_message)
result = await client.route_message(json_message)
if result == dingtalk_stream.DingTalkStreamClient.TAG_DISCONNECT:
break
loop.run_until_complete(_session())
except (KeyboardInterrupt, SystemExit):
logger.info("[DingTalk] Session loop received stop signal, exiting")
break
except Exception as e:
if not self._running:
break
logger.warning(f"[DingTalk] Stream session error: {e}, reconnecting in 3s...")
for _ in range(30):
if not self._running:
break
time.sleep(0.1)
finally:
self._event_loop = None
try:
loop.close()
except Exception:
pass
logger.info("[DingTalk] Startup loop exited")
def stop(self):
logger.info("[DingTalk] stop() called, setting _running=False")
self._running = False
loop = self._event_loop
if loop and not loop.is_closed():
try:
loop.call_soon_threadsafe(loop.stop)
logger.info("[DingTalk] Sent stop signal to event loop")
except Exception as e:
logger.warning(f"[DingTalk] Error stopping event loop: {e}")
self._stream_client = None
logger.info("[DingTalk] stop() completed")
def get_access_token(self):
"""
@@ -276,7 +394,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
# 保存到临时文件
file_name = os.path.basename(file_path) or f"media_{uuid.uuid4()}"
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(workspace_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
temp_file = os.path.join(tmp_dir, file_name)
@@ -457,23 +575,21 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
async def process(self, callback: dingtalk_stream.CallbackMessage):
try:
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
# 缓存 robot_code用于后续图片下载
if hasattr(incoming_message, 'robot_code'):
self._robot_code_cache = incoming_message.robot_code
# Debug: 打印完整的 event 数据
logger.debug(f"[DingTalk] ===== Incoming Message Debug =====")
logger.debug(f"[DingTalk] callback.data keys: {callback.data.keys() if hasattr(callback.data, 'keys') else 'N/A'}")
logger.debug(f"[DingTalk] incoming_message attributes: {dir(incoming_message)}")
logger.debug(f"[DingTalk] robot_code: {getattr(incoming_message, 'robot_code', 'N/A')}")
logger.debug(f"[DingTalk] chatbot_corp_id: {getattr(incoming_message, 'chatbot_corp_id', 'N/A')}")
logger.debug(f"[DingTalk] chatbot_user_id: {getattr(incoming_message, 'chatbot_user_id', 'N/A')}")
logger.debug(f"[DingTalk] conversation_id: {getattr(incoming_message, 'conversation_id', 'N/A')}")
logger.debug(f"[DingTalk] Raw callback.data: {callback.data}")
logger.debug(f"[DingTalk] =====================================")
image_download_handler = self # 传入方法所在的类实例
# Filter out stale messages from before channel startup (offline backlog)
create_at = getattr(incoming_message, 'create_at', None)
if create_at:
msg_age_s = time.time() - int(create_at) / 1000
if msg_age_s > 60:
logger.warning(f"[DingTalk] stale msg filtered (age={msg_age_s:.0f}s), "
f"msg_id={getattr(incoming_message, 'message_id', 'N/A')}")
return AckMessage.STATUS_OK, 'OK'
image_download_handler = self
dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler)
if dingtalk_msg.is_group:
@@ -482,8 +598,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
self.handle_single(dingtalk_msg)
return AckMessage.STATUS_OK, 'OK'
except Exception as e:
logger.error(f"[DingTalk] process error: {e}")
logger.exception(e) # 打印完整堆栈跟踪
logger.error(f"[DingTalk] process error: {e}", exc_info=True)
return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR'
@time_checker
@@ -607,7 +722,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
def send(self, reply: Reply, context: Context):
logger.info(f"[DingTalk] send() called with reply.type={reply.type}, content_length={len(str(reply.content))}")
logger.debug(f"[DingTalk] send() called with reply.type={reply.type}, content_length={len(str(reply.content))}")
receiver = context["receiver"]
# Check if msg exists (for scheduled tasks, msg might be None)
@@ -647,7 +762,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
robot_code = msg.robot_code
if robot_code and robot_code != self._robot_code:
self._robot_code = robot_code
logger.info(f"[DingTalk] Cached robot_code: {robot_code}")
logger.debug(f"[DingTalk] Cached robot_code: {robot_code}")
isgroup = msg.is_group
incoming_message = msg.incoming_message

View File

@@ -9,6 +9,7 @@ from channel.chat_message import ChatMessage
# -*- coding=utf-8 -*-
from common.log import logger
from common.tmp_dir import TmpDir
from common.utils import expand_path
from config import conf
@@ -49,7 +50,7 @@ class DingTalkMessage(ChatMessage):
download_url = image_download_handler.get_image_download_url(download_code)
# 下载到工作空间 tmp 目录
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(workspace_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
@@ -67,7 +68,7 @@ class DingTalkMessage(ChatMessage):
self.ctype = ContextType.TEXT
# 下载到工作空间 tmp 目录
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(workspace_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)

View File

@@ -140,6 +140,23 @@ python3 app.py
**解决**: 安装依赖 `pip install lark-oapi`
### SSL证书验证失败
```
[Lark][ERROR] connect failed, err:[SSL:CERTIFICATE_VERIFY_FAILED] certificate verify failed: self signed certificate in certificate chain
```
**原因**: 网络环境中存在自签名证书或SSL中间人代理(如企业代理、VPN等)
**解决**: 程序会自动检测SSL证书验证失败并自动重试禁用证书验证的连接。无需手动配置。
当遇到证书错误时,日志会显示:
```
[FeiShu] SSL certificate verification disabled due to certificate error. This may happen when using corporate proxy or self-signed certificates.
```
这是正常现象,程序会自动处理并继续运行。
### Webhook模式端口被占用
```

View File

@@ -11,8 +11,11 @@
@Date 2023/11/19
"""
import importlib.util
import json
import logging
import os
import ssl
import threading
# -*- coding=utf-8 -*-
import uuid
@@ -31,17 +34,25 @@ from common.log import logger
from common.singleton import singleton
from config import conf
# Suppress verbose logs from Lark SDK
logging.getLogger("Lark").setLevel(logging.WARNING)
URL_VERIFICATION = "url_verification"
# 尝试导入飞书SDK,如果未安装则websocket模式不可用
try:
import lark_oapi as lark
# Lazy-check for lark_oapi SDK availability without importing it at module level.
# The full `import lark_oapi` pulls in 10k+ files and takes 4-10s, so we defer
# the actual import to _startup_websocket() where it is needed.
LARK_SDK_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
lark = None # will be populated on first use via _ensure_lark_imported()
LARK_SDK_AVAILABLE = True
except ImportError:
LARK_SDK_AVAILABLE = False
logger.warning(
"[FeiShu] lark_oapi not installed, websocket mode is not available. Install with: pip install lark-oapi")
def _ensure_lark_imported():
"""Import lark_oapi on first use (takes 4-10s due to 10k+ source files)."""
global lark
if lark is None:
import lark_oapi as _lark
lark = _lark
return lark
@singleton
@@ -55,6 +66,10 @@ class FeiShuChanel(ChatChannel):
super().__init__()
# 历史消息id暂存用于幂等控制
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
self._http_server = None
self._ws_client = None
self._ws_thread = None
self._bot_open_id = None # cached bot open_id for @-mention matching
logger.debug("[FeiShu] app_id={}, app_secret={}, verification_token={}, event_mode={}".format(
self.feishu_app_id, self.feishu_app_secret, self.feishu_token, self.feishu_event_mode))
# 无需群校验和前缀
@@ -67,11 +82,66 @@ class FeiShuChanel(ChatChannel):
raise Exception("lark_oapi not installed")
def startup(self):
self.feishu_app_id = conf().get('feishu_app_id')
self.feishu_app_secret = conf().get('feishu_app_secret')
self.feishu_token = conf().get('feishu_token')
self.feishu_event_mode = conf().get('feishu_event_mode', 'websocket')
self._fetch_bot_open_id()
if self.feishu_event_mode == 'websocket':
self._startup_websocket()
else:
self._startup_webhook()
def _fetch_bot_open_id(self):
"""Fetch the bot's own open_id via API so we can match @-mentions without feishu_bot_name."""
try:
access_token = self.fetch_access_token()
if not access_token:
logger.warning("[FeiShu] Cannot fetch bot info: no access_token")
return
headers = {"Authorization": "Bearer " + access_token}
resp = requests.get("https://open.feishu.cn/open-apis/bot/v3/info/", headers=headers, timeout=5)
if resp.status_code == 200:
data = resp.json()
if data.get("code") == 0:
self._bot_open_id = data.get("bot", {}).get("open_id")
logger.info(f"[FeiShu] Bot open_id fetched: {self._bot_open_id}")
else:
logger.warning(f"[FeiShu] Fetch bot info failed: code={data.get('code')}, msg={data.get('msg')}")
except Exception as e:
logger.warning(f"[FeiShu] Fetch bot open_id error: {e}")
def stop(self):
import ctypes
logger.info("[FeiShu] stop() called")
ws_client = self._ws_client
self._ws_client = None
ws_thread = self._ws_thread
self._ws_thread = None
# Interrupt the ws thread first so its blocking start() unblocks
if ws_thread and ws_thread.is_alive():
try:
tid = ws_thread.ident
if tid:
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
)
if res == 1:
logger.info("[FeiShu] Interrupted ws thread via ctypes")
elif res > 1:
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
except Exception as e:
logger.warning(f"[FeiShu] Error interrupting ws thread: {e}")
# lark.ws.Client has no stop() method; thread interruption above is sufficient
if self._http_server:
try:
self._http_server.stop()
logger.info("[FeiShu] HTTP server stopped")
except Exception as e:
logger.warning(f"[FeiShu] Error stopping HTTP server: {e}")
self._http_server = None
logger.info("[FeiShu] stop() completed")
def _startup_webhook(self):
"""启动HTTP服务器接收事件(webhook模式)"""
logger.debug("[FeiShu] Starting in webhook mode...")
@@ -80,21 +150,33 @@ class FeiShuChanel(ChatChannel):
)
app = web.application(urls, globals(), autoreload=False)
port = conf().get("feishu_port", 9891)
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
func = web.httpserver.StaticMiddleware(app.wsgifunc())
func = web.httpserver.LogMiddleware(func)
server = web.httpserver.WSGIServer(("0.0.0.0", port), func)
self._http_server = server
try:
server.start()
except (KeyboardInterrupt, SystemExit):
server.stop()
def _startup_websocket(self):
"""启动长连接接收事件(websocket模式)"""
_ensure_lark_imported()
logger.debug("[FeiShu] Starting in websocket mode...")
# 创建事件处理器
def handle_message_event(data: lark.im.v1.P2ImMessageReceiveV1) -> None:
"""处理接收消息事件 v2.0"""
try:
logger.debug(f"[FeiShu] websocket receive event: {lark.JSON.marshal(data, indent=2)}")
# 转换为标准的event格式
event_dict = json.loads(lark.JSON.marshal(data))
event = event_dict.get("event", {})
msg = event.get("message", {})
# Skip group messages that don't @-mention the bot (reduce log noise)
if msg.get("chat_type") == "group" and not msg.get("mentions") and msg.get("message_type") == "text":
return
logger.debug(f"[FeiShu] websocket receive event: {lark.JSON.marshal(data, indent=2)}")
# 处理消息
self._handle_message_event(event)
@@ -107,29 +189,99 @@ class FeiShuChanel(ChatChannel):
.register_p2_im_message_receive_v1(handle_message_event) \
.build()
# 创建长连接客户端
ws_client = lark.ws.Client(
self.feishu_app_id,
self.feishu_app_secret,
event_handler=event_handler,
log_level=lark.LogLevel.DEBUG if conf().get("debug") else lark.LogLevel.INFO
)
def start_client_with_retry():
"""Run ws client in this thread with its own event loop to avoid conflicts."""
import asyncio
import ssl as ssl_module
original_create_default_context = ssl_module.create_default_context
# 在新线程中启动客户端,避免阻塞主线程
def start_client():
def create_unverified_context(*args, **kwargs):
context = original_create_default_context(*args, **kwargs)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
return context
# lark_oapi.ws.client captures the event loop at module-import time as a module-
# level global variable. When a previous ws thread is force-killed via ctypes its
# loop may still be marked as "running", which causes the next ws_client.start()
# call (in this new thread) to raise "This event loop is already running".
# Fix: replace the module-level loop with a brand-new, idle loop before starting.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
logger.debug("[FeiShu] Websocket client starting...")
ws_client.start()
except Exception as e:
logger.error(f"[FeiShu] Websocket client error: {e}", exc_info=True)
import lark_oapi.ws.client as _lark_ws_client_mod
_lark_ws_client_mod.loop = loop
except Exception:
pass
ws_thread = threading.Thread(target=start_client, daemon=True)
startup_error = None
for attempt in range(2):
try:
if attempt == 1:
logger.warning("[FeiShu] Retrying with SSL verification disabled...")
ssl_module.create_default_context = create_unverified_context
ssl_module._create_unverified_context = create_unverified_context
ws_client = lark.ws.Client(
self.feishu_app_id,
self.feishu_app_secret,
event_handler=event_handler,
log_level=lark.LogLevel.WARNING
)
self._ws_client = ws_client
logger.debug("[FeiShu] Websocket client starting...")
ws_client.start()
break
except (SystemExit, KeyboardInterrupt):
logger.info("[FeiShu] Websocket thread received stop signal")
break
except Exception as e:
error_msg = str(e)
is_ssl_error = ("CERTIFICATE_VERIFY_FAILED" in error_msg
or "certificate verify failed" in error_msg.lower())
if is_ssl_error and attempt == 0:
logger.warning(f"[FeiShu] SSL error: {error_msg}, retrying...")
continue
logger.error(f"[FeiShu] Websocket client error: {e}", exc_info=True)
startup_error = error_msg
ssl_module.create_default_context = original_create_default_context
break
if startup_error:
self.report_startup_error(startup_error)
try:
loop.close()
except Exception:
pass
logger.info("[FeiShu] Websocket thread exited")
ws_thread = threading.Thread(target=start_client_with_retry, daemon=True)
self._ws_thread = ws_thread
ws_thread.start()
# 保持主线程运行
logger.info("[FeiShu] ✅ Websocket connected, ready to receive messages")
logger.info("[FeiShu] ✅ Websocket thread started, ready to receive messages")
ws_thread.join()
def _is_mention_bot(self, mentions: list) -> bool:
"""Check whether any mention in the list refers to this bot.
Priority:
1. Match by open_id (obtained from /bot/v3/info at startup, no config needed)
2. Fallback to feishu_bot_name config for backward compatibility
3. If neither is available, assume the first mention is the bot (Feishu only
delivers group messages that @-mention the bot, so this is usually correct)
"""
if self._bot_open_id:
return any(
m.get("id", {}).get("open_id") == self._bot_open_id
for m in mentions
)
bot_name = conf().get("feishu_bot_name")
if bot_name:
return any(m.get("name") == bot_name for m in mentions)
# Feishu event subscription only delivers messages that @-mention the bot,
# so reaching here means the bot was indeed mentioned.
return True
def _handle_message_event(self, event: dict):
"""
处理消息事件的核心逻辑
@@ -148,6 +300,15 @@ class FeiShuChanel(ChatChannel):
return
self.receivedMsgs[msg_id] = True
# Filter out stale messages from before channel startup (offline backlog)
import time as _time
create_time_ms = msg.get("create_time")
if create_time_ms:
msg_age_s = _time.time() - int(create_time_ms) / 1000
if msg_age_s > 60:
logger.warning(f"[FeiShu] stale msg filtered (age={msg_age_s:.0f}s), msg_id={msg_id}")
return
is_group = False
chat_type = msg.get("chat_type")
@@ -155,10 +316,9 @@ class FeiShuChanel(ChatChannel):
if not msg.get("mentions") and msg.get("message_type") == "text":
# 群聊中未@不响应
return
if msg.get("mentions") and msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name") and msg.get(
"message_type") == "text":
# 不是@机器人,不响应
return
if msg.get("mentions") and msg.get("message_type") == "text":
if not self._is_mention_bot(msg.get("mentions")):
return
# 群聊
is_group = True
receive_id_type = "chat_id"
@@ -176,7 +336,7 @@ class FeiShuChanel(ChatChannel):
# 处理文件缓存逻辑
from channel.file_cache import get_file_cache
file_cache = get_file_cache()
# 获取 session_id用于缓存关联
if is_group:
if conf().get("group_shared_session", True):
@@ -185,7 +345,7 @@ class FeiShuChanel(ChatChannel):
session_id = feishu_msg.from_user_id + "_" + msg.get("chat_id")
else:
session_id = feishu_msg.from_user_id
# 如果是单张图片消息,缓存起来
if feishu_msg.ctype == ContextType.IMAGE:
if hasattr(feishu_msg, 'image_path') and feishu_msg.image_path:
@@ -193,7 +353,7 @@ class FeiShuChanel(ChatChannel):
logger.info(f"[FeiShu] Image cached for session {session_id}, waiting for user query...")
# 单张图片不直接处理,等待用户提问
return
# 如果是文本消息,检查是否有缓存的文件
if feishu_msg.ctype == ContextType.TEXT:
cached_files = file_cache.get(session_id)
@@ -209,7 +369,7 @@ class FeiShuChanel(ChatChannel):
file_refs.append(f"[视频: {file_path}]")
else:
file_refs.append(f"[文件: {file_path}]")
feishu_msg.content = feishu_msg.content + "\n" + "\n".join(file_refs)
logger.info(f"[FeiShu] Attached {len(cached_files)} cached file(s) to user query")
# 清除缓存
@@ -251,28 +411,35 @@ class FeiShuChanel(ChatChannel):
msg_type = "image"
content_key = "image_key"
elif reply.type == ReplyType.FILE:
# 如果有附加的文本内容,先发送文本
if hasattr(reply, 'text_content') and reply.text_content:
logger.info(f"[FeiShu] Sending text before file: {reply.text_content[:50]}...")
text_reply = Reply(ReplyType.TEXT, reply.text_content)
self._send(text_reply, context)
import time
time.sleep(0.3) # 短暂延迟,确保文本先到达
# 判断是否为视频文件
file_path = reply.content
if file_path.startswith("file://"):
file_path = file_path[7:]
is_video = file_path.lower().endswith(('.mp4', '.avi', '.mov', '.wmv', '.flv'))
if is_video:
# 视频使用 media 类型,需要上传并获取 file_key 和 duration
video_info = self._upload_video_url(reply.content, access_token)
if not video_info or not video_info.get('file_key'):
# 视频上传(包含duration信息)
upload_data = self._upload_video_url(reply.content, access_token)
if not upload_data or not upload_data.get('file_key'):
logger.warning("[FeiShu] upload video failed")
return
# media 类型需要特殊的 content 格式
# 视频使用 media 类型(根据官方文档)
# 错误码 230055 说明:上传 mp4 时必须使用 msg_type="media"
msg_type = "media"
# 注意media 类型的 content 不使用 content_key而是完整的 JSON 对象
reply_content = {
"file_key": video_info['file_key'],
"duration": video_info.get('duration', 0) # 视频时长(毫秒)
}
content_key = None # media 类型不使用单一的 key
reply_content = upload_data # 完整的上传响应数据包含file_key和duration
logger.info(
f"[FeiShu] Sending video: file_key={upload_data.get('file_key')}, duration={upload_data.get('duration')}ms")
content_key = None # 直接序列化整个对象
else:
# 其他文件使用 file 类型
file_key = self._upload_file_url(reply.content, access_token)
@@ -282,16 +449,20 @@ class FeiShuChanel(ChatChannel):
reply_content = file_key
msg_type = "file"
content_key = "file_key"
# Check if we can reply to an existing message (need msg_id)
can_reply = is_group and msg and hasattr(msg, 'msg_id') and msg.msg_id
# Build content JSON
content_json = json.dumps(reply_content, ensure_ascii=False) if content_key is None else json.dumps({content_key: reply_content}, ensure_ascii=False)
logger.debug(f"[FeiShu] Sending message: msg_type={msg_type}, content={content_json[:200]}")
if can_reply:
# 群聊中回复已有消息
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.msg_id}/reply"
data = {
"msg_type": msg_type,
"content": json.dumps(reply_content) if content_key is None else json.dumps({content_key: reply_content})
"content": content_json
}
res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10))
else:
@@ -301,7 +472,7 @@ class FeiShuChanel(ChatChannel):
data = {
"receive_id": context.get("receiver"),
"msg_type": msg_type,
"content": json.dumps(reply_content) if content_key is None else json.dumps({content_key: reply_content})
"content": content_json
}
res = requests.post(url=url, headers=headers, params=params, json=data, timeout=(5, 10))
res = res.json()
@@ -310,7 +481,6 @@ class FeiShuChanel(ChatChannel):
else:
logger.error(f"[FeiShu] send message failed, code={res.get('code')}, msg={res.get('msg')}")
def fetch_access_token(self) -> str:
url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal/"
headers = {
@@ -332,35 +502,34 @@ class FeiShuChanel(ChatChannel):
else:
logger.error(f"[FeiShu] fetch token error, res={response}")
def _upload_image_url(self, img_url, access_token):
logger.debug(f"[FeiShu] start process image, img_url={img_url}")
# Check if it's a local file path (file:// protocol)
if img_url.startswith("file://"):
local_path = img_url[7:] # Remove "file://" prefix
logger.info(f"[FeiShu] uploading local file: {local_path}")
if not os.path.exists(local_path):
logger.error(f"[FeiShu] local file not found: {local_path}")
return None
# Upload directly from local file
upload_url = "https://open.feishu.cn/open-apis/im/v1/images"
data = {'image_type': 'message'}
headers = {'Authorization': f'Bearer {access_token}'}
with open(local_path, "rb") as file:
upload_response = requests.post(upload_url, files={"image": file}, data=data, headers=headers)
logger.info(f"[FeiShu] upload file, res={upload_response.content}")
response_data = upload_response.json()
if response_data.get("code") == 0:
return response_data.get("data").get("image_key")
else:
logger.error(f"[FeiShu] upload failed: {response_data}")
return None
# Original logic for HTTP URLs
response = requests.get(img_url)
suffix = utils.get_path_suffix(img_url)
@@ -396,7 +565,7 @@ class FeiShuChanel(ChatChannel):
"""
try:
import subprocess
# 使用 ffprobe 获取视频时长
cmd = [
'ffprobe',
@@ -405,7 +574,7 @@ class FeiShuChanel(ChatChannel):
'-of', 'default=noprint_wrappers=1:nokey=1',
file_path
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
if result.returncode == 0:
duration_seconds = float(result.stdout.strip())
@@ -434,7 +603,7 @@ class FeiShuChanel(ChatChannel):
"""
local_path = None
temp_file = None
try:
# For file:// URLs (local files), upload directly
if video_url.startswith("file://"):
@@ -449,56 +618,67 @@ class FeiShuChanel(ChatChannel):
if response.status_code != 200:
logger.error(f"[FeiShu] download video failed, status={response.status_code}")
return None
# Save to temp file
import uuid
file_name = os.path.basename(video_url) or "video.mp4"
temp_file = str(uuid.uuid4()) + "_" + file_name
with open(temp_file, "wb") as file:
file.write(response.content)
logger.info(f"[FeiShu] Video downloaded, size={len(response.content)} bytes")
local_path = temp_file
# Get video duration
duration = self._get_video_duration(local_path)
# Upload to Feishu
file_name = os.path.basename(local_path)
file_ext = os.path.splitext(file_name)[1].lower()
file_type_map = {'.mp4': 'mp4'}
file_type = file_type_map.get(file_ext, 'mp4')
upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
data = {'file_type': file_type, 'file_name': file_name}
data = {
'file_type': file_type,
'file_name': file_name
}
# Add duration only if available (required for video/audio)
if duration:
data['duration'] = duration # Must be int, not string
headers = {'Authorization': f'Bearer {access_token}'}
logger.info(f"[FeiShu] Uploading video: file_name={file_name}, duration={duration}ms")
with open(local_path, "rb") as file:
upload_response = requests.post(
upload_url,
files={"file": file},
data=data,
headers=headers,
upload_url,
files={"file": file},
data=data,
headers=headers,
timeout=(5, 60)
)
logger.info(f"[FeiShu] upload video response, status={upload_response.status_code}, res={upload_response.content}")
logger.info(
f"[FeiShu] upload video response, status={upload_response.status_code}, res={upload_response.content}")
response_data = upload_response.json()
if response_data.get("code") == 0:
file_key = response_data.get("data").get("file_key")
return {
'file_key': file_key,
'duration': duration
}
# Add duration to the response data (API doesn't return it)
upload_data = response_data.get("data")
upload_data['duration'] = duration # Add our calculated duration
logger.info(
f"[FeiShu] Upload complete: file_key={upload_data.get('file_key')}, duration={duration}ms")
return upload_data
else:
logger.error(f"[FeiShu] upload video failed: {response_data}")
return None
except Exception as e:
logger.error(f"[FeiShu] upload video exception: {e}")
return None
finally:
# Clean up temp file
if temp_file and os.path.exists(temp_file):
@@ -513,20 +693,20 @@ class FeiShuChanel(ChatChannel):
Supports both local files (file://) and HTTP URLs
"""
logger.debug(f"[FeiShu] start process file, file_url={file_url}")
# Check if it's a local file path (file:// protocol)
if file_url.startswith("file://"):
local_path = file_url[7:] # Remove "file://" prefix
logger.info(f"[FeiShu] uploading local file: {local_path}")
if not os.path.exists(local_path):
logger.error(f"[FeiShu] local file not found: {local_path}")
return None
# Get file info
file_name = os.path.basename(local_path)
file_ext = os.path.splitext(file_name)[1].lower()
# Determine file type for Feishu API
# Feishu supports: opus, mp4, pdf, doc, xls, ppt, stream (other types)
file_type_map = {
@@ -538,23 +718,24 @@ class FeiShuChanel(ChatChannel):
'.ppt': 'ppt', '.pptx': 'ppt',
}
file_type = file_type_map.get(file_ext, 'stream') # Default to stream for other types
# Upload file to Feishu
upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
data = {'file_type': file_type, 'file_name': file_name}
headers = {'Authorization': f'Bearer {access_token}'}
try:
with open(local_path, "rb") as file:
upload_response = requests.post(
upload_url,
files={"file": file},
data=data,
upload_url,
files={"file": file},
data=data,
headers=headers,
timeout=(5, 30) # 5s connect, 30s read timeout
)
logger.info(f"[FeiShu] upload file response, status={upload_response.status_code}, res={upload_response.content}")
logger.info(
f"[FeiShu] upload file response, status={upload_response.status_code}, res={upload_response.content}")
response_data = upload_response.json()
if response_data.get("code") == 0:
return response_data.get("data").get("file_key")
@@ -564,22 +745,22 @@ class FeiShuChanel(ChatChannel):
except Exception as e:
logger.error(f"[FeiShu] upload file exception: {e}")
return None
# For HTTP URLs, download first then upload
try:
response = requests.get(file_url, timeout=(5, 30))
if response.status_code != 200:
logger.error(f"[FeiShu] download file failed, status={response.status_code}")
return None
# Save to temp file
import uuid
file_name = os.path.basename(file_url)
temp_name = str(uuid.uuid4()) + "_" + file_name
with open(temp_name, "wb") as file:
file.write(response.content)
# Upload
file_ext = os.path.splitext(file_name)[1].lower()
file_type_map = {
@@ -589,18 +770,18 @@ class FeiShuChanel(ChatChannel):
'.ppt': 'ppt', '.pptx': 'ppt',
}
file_type = file_type_map.get(file_ext, 'stream')
upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
data = {'file_type': file_type, 'file_name': file_name}
headers = {'Authorization': f'Bearer {access_token}'}
with open(temp_name, "rb") as file:
upload_response = requests.post(upload_url, files={"file": file}, data=data, headers=headers)
logger.info(f"[FeiShu] upload file, res={upload_response.content}")
response_data = upload_response.json()
os.remove(temp_name) # Clean up temp file
if response_data.get("code") == 0:
return response_data.get("data").get("file_key")
else:
@@ -613,11 +794,13 @@ class FeiShuChanel(ChatChannel):
def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content)
context.kwargs = kwargs
if "channel_type" not in context:
context["channel_type"] = self.channel_type
if "origin_ctype" not in context:
context["origin_ctype"] = ctype
cmsg = context["msg"]
# Set session_id based on chat type
if cmsg.is_group:
# Group chat: check if group_shared_session is enabled
@@ -633,7 +816,7 @@ class FeiShuChanel(ChatChannel):
else:
# Private chat: use user_id only
context["session_id"] = cmsg.from_user_id
context["receiver"] = cmsg.other_user_id
if ctype == ContextType.TEXT:

View File

@@ -6,6 +6,7 @@ import requests
from common.log import logger
from common.tmp_dir import TmpDir
from common import utils
from common.utils import expand_path
from config import conf
@@ -31,7 +32,7 @@ class FeishuMessage(ChatMessage):
image_key = content.get("image_key")
# 下载图片到工作空间临时目录
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(workspace_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
image_path = os.path.join(tmp_dir, f"{image_key}.png")
@@ -97,7 +98,7 @@ class FeishuMessage(ChatMessage):
if image_keys:
# 如果包含图片,下载并在文本中引用本地路径
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(workspace_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)

0
channel/qq/__init__.py Normal file
View File

736
channel/qq/qq_channel.py Normal file
View File

@@ -0,0 +1,736 @@
"""
QQ Bot channel via WebSocket long connection.
Supports:
- Group chat (@bot), single chat (C2C), guild channel, guild DM
- Text / image / file message send & receive
- Heartbeat keep-alive and auto-reconnect with session resume
"""
import base64
import json
import os
import threading
import time
import requests
import websocket
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel, check_prefix
from channel.qq.qq_message import QQMessage
from common.expired_dict import ExpiredDict
from common.log import logger
from common.singleton import singleton
from common.ws_client_compat import websocket_app_run_forever
from config import conf
# Rich media file_type constants
QQ_FILE_TYPE_IMAGE = 1
QQ_FILE_TYPE_VIDEO = 2
QQ_FILE_TYPE_VOICE = 3
QQ_FILE_TYPE_FILE = 4
QQ_API_BASE = "https://api.sgroup.qq.com"
# Intents: GROUP_AND_C2C_EVENT(1<<25) | PUBLIC_GUILD_MESSAGES(1<<30)
DEFAULT_INTENTS = (1 << 25) | (1 << 30)
# OpCode constants
OP_DISPATCH = 0
OP_HEARTBEAT = 1
OP_IDENTIFY = 2
OP_RESUME = 6
OP_RECONNECT = 7
OP_INVALID_SESSION = 9
OP_HELLO = 10
OP_HEARTBEAT_ACK = 11
# Resumable error codes
RESUMABLE_CLOSE_CODES = {4008, 4009}
@singleton
class QQChannel(ChatChannel):
def __init__(self):
super().__init__()
self.app_id = ""
self.app_secret = ""
self._access_token = ""
self._token_expires_at = 0
self._ws = None
self._ws_thread = None
self._heartbeat_thread = None
self._connected = False
self._stop_event = threading.Event()
self._token_lock = threading.Lock()
self._session_id = None
self._last_seq = None
self._heartbeat_interval = 45000
self._can_resume = False
self.received_msgs = ExpiredDict(60 * 60 * 7.1)
self._msg_seq_counter = {}
conf()["group_name_white_list"] = ["ALL_GROUP"]
conf()["single_chat_prefix"] = [""]
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def startup(self):
self.app_id = conf().get("qq_app_id", "")
self.app_secret = conf().get("qq_app_secret", "")
if not self.app_id or not self.app_secret:
err = "[QQ] qq_app_id and qq_app_secret are required"
logger.error(err)
self.report_startup_error(err)
return
self._refresh_access_token()
if not self._access_token:
err = "[QQ] Failed to get initial access_token"
logger.error(err)
self.report_startup_error(err)
return
self._stop_event.clear()
self._start_ws()
def stop(self):
logger.info("[QQ] stop() called")
self._stop_event.set()
if self._ws:
try:
self._ws.close()
except Exception:
pass
self._ws = None
self._connected = False
# ------------------------------------------------------------------
# Access Token
# ------------------------------------------------------------------
def _refresh_access_token(self):
try:
resp = requests.post(
"https://bots.qq.com/app/getAppAccessToken",
json={"appId": self.app_id, "clientSecret": self.app_secret},
timeout=10,
)
resp.raise_for_status()
data = resp.json()
self._access_token = data.get("access_token", "")
expires_in = int(data.get("expires_in", 7200))
self._token_expires_at = time.time() + expires_in - 60
logger.debug(f"[QQ] Access token refreshed, expires_in={expires_in}s")
except Exception as e:
logger.error(f"[QQ] Failed to refresh access_token: {e}")
def _get_access_token(self) -> str:
with self._token_lock:
if time.time() >= self._token_expires_at:
self._refresh_access_token()
return self._access_token
def _get_auth_headers(self) -> dict:
return {
"Authorization": f"QQBot {self._get_access_token()}",
"Content-Type": "application/json",
}
# ------------------------------------------------------------------
# WebSocket connection
# ------------------------------------------------------------------
def _get_ws_url(self) -> str:
try:
resp = requests.get(
f"{QQ_API_BASE}/gateway",
headers=self._get_auth_headers(),
timeout=10,
)
resp.raise_for_status()
url = resp.json().get("url", "")
logger.debug(f"[QQ] Gateway URL: {url}")
return url
except Exception as e:
logger.error(f"[QQ] Failed to get gateway URL: {e}")
return ""
def _start_ws(self):
ws_url = self._get_ws_url()
if not ws_url:
logger.error("[QQ] Cannot start WebSocket without gateway URL")
self.report_startup_error("Failed to get gateway URL")
return
def _on_open(ws):
logger.debug("[QQ] WebSocket connected, waiting for Hello...")
def _on_message(ws, raw):
try:
data = json.loads(raw)
self._handle_ws_message(data)
except Exception as e:
logger.error(f"[QQ] Failed to handle ws message: {e}", exc_info=True)
def _on_error(ws, error):
logger.error(f"[QQ] WebSocket error: {error}")
def _on_close(ws, close_status_code, close_msg):
logger.warning(f"[QQ] WebSocket closed: status={close_status_code}, msg={close_msg}")
self._connected = False
if not self._stop_event.is_set():
if close_status_code in RESUMABLE_CLOSE_CODES and self._session_id:
self._can_resume = True
logger.info("[QQ] Will attempt resume in 3s...")
time.sleep(3)
else:
self._can_resume = False
logger.info("[QQ] Will reconnect in 5s...")
time.sleep(5)
if not self._stop_event.is_set():
self._start_ws()
self._ws = websocket.WebSocketApp(
ws_url,
on_open=_on_open,
on_message=_on_message,
on_error=_on_error,
on_close=_on_close,
)
def run_forever():
try:
websocket_app_run_forever(self._ws, ping_interval=0, reconnect=0)
except (SystemExit, KeyboardInterrupt):
logger.info("[QQ] WebSocket thread interrupted")
except Exception as e:
logger.error(f"[QQ] WebSocket run_forever error: {e}")
self._ws_thread = threading.Thread(target=run_forever, daemon=True)
self._ws_thread.start()
self._ws_thread.join()
def _ws_send(self, data: dict):
if self._ws:
self._ws.send(json.dumps(data, ensure_ascii=False))
# ------------------------------------------------------------------
# Identify & Resume & Heartbeat
# ------------------------------------------------------------------
def _send_identify(self):
self._ws_send({
"op": OP_IDENTIFY,
"d": {
"token": f"QQBot {self._get_access_token()}",
"intents": DEFAULT_INTENTS,
"shard": [0, 1],
"properties": {
"$os": "linux",
"$browser": "chatgpt-on-wechat",
"$device": "chatgpt-on-wechat",
},
},
})
logger.debug(f"[QQ] Identify sent with intents={DEFAULT_INTENTS}")
def _send_resume(self):
self._ws_send({
"op": OP_RESUME,
"d": {
"token": f"QQBot {self._get_access_token()}",
"session_id": self._session_id,
"seq": self._last_seq,
},
})
logger.debug(f"[QQ] Resume sent: session_id={self._session_id}, seq={self._last_seq}")
def _start_heartbeat(self, interval_ms: int):
if self._heartbeat_thread and self._heartbeat_thread.is_alive():
return
self._heartbeat_interval = interval_ms
interval_sec = interval_ms / 1000.0
def heartbeat_loop():
while not self._stop_event.is_set() and self._connected:
try:
self._ws_send({
"op": OP_HEARTBEAT,
"d": self._last_seq,
})
except Exception as e:
logger.warning(f"[QQ] Heartbeat send failed: {e}")
break
self._stop_event.wait(interval_sec)
self._heartbeat_thread = threading.Thread(target=heartbeat_loop, daemon=True)
self._heartbeat_thread.start()
# ------------------------------------------------------------------
# Incoming message dispatch
# ------------------------------------------------------------------
def _handle_ws_message(self, data: dict):
op = data.get("op")
d = data.get("d")
t = data.get("t")
s = data.get("s")
if s is not None:
self._last_seq = s
if op == OP_HELLO:
heartbeat_interval = d.get("heartbeat_interval", 45000) if d else 45000
logger.debug(f"[QQ] Received Hello, heartbeat_interval={heartbeat_interval}ms")
self._heartbeat_interval = heartbeat_interval
if self._can_resume and self._session_id:
self._send_resume()
else:
self._send_identify()
elif op == OP_HEARTBEAT_ACK:
pass
elif op == OP_HEARTBEAT:
self._ws_send({"op": OP_HEARTBEAT, "d": self._last_seq})
elif op == OP_RECONNECT:
logger.warning("[QQ] Server requested reconnect")
self._can_resume = True
if self._ws:
self._ws.close()
elif op == OP_INVALID_SESSION:
logger.warning("[QQ] Invalid session, re-identifying...")
self._session_id = None
self._can_resume = False
time.sleep(2)
self._send_identify()
elif op == OP_DISPATCH:
if t == "READY":
self._session_id = d.get("session_id", "")
user = d.get("user", {})
bot_name = user.get('username', '')
logger.info(f"[QQ] ✅ Connected successfully (bot={bot_name})")
self._connected = True
self._can_resume = False
self._start_heartbeat(self._heartbeat_interval)
self.report_startup_success()
elif t == "RESUMED":
logger.info("[QQ] Session resumed successfully")
self._connected = True
self._can_resume = False
self._start_heartbeat(self._heartbeat_interval)
elif t in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE",
"AT_MESSAGE_CREATE", "DIRECT_MESSAGE_CREATE"):
self._handle_msg_event(d, t)
elif t in ("GROUP_ADD_ROBOT", "FRIEND_ADD"):
logger.info(f"[QQ] Event: {t}")
else:
logger.debug(f"[QQ] Dispatch event: {t}")
# ------------------------------------------------------------------
# Message event handling
# ------------------------------------------------------------------
def _handle_msg_event(self, event_data: dict, event_type: str):
msg_id = event_data.get("id", "")
if self.received_msgs.get(msg_id):
logger.debug(f"[QQ] Duplicate msg filtered: {msg_id}")
return
self.received_msgs[msg_id] = True
try:
qq_msg = QQMessage(event_data, event_type)
except NotImplementedError as e:
logger.warning(f"[QQ] {e}")
return
except Exception as e:
logger.error(f"[QQ] Failed to parse message: {e}", exc_info=True)
return
is_group = qq_msg.is_group
from channel.file_cache import get_file_cache
file_cache = get_file_cache()
if is_group:
session_id = qq_msg.other_user_id
else:
session_id = qq_msg.from_user_id
if qq_msg.ctype == ContextType.IMAGE:
if hasattr(qq_msg, "image_path") and qq_msg.image_path:
file_cache.add(session_id, qq_msg.image_path, file_type="image")
logger.info(f"[QQ] Image cached for session {session_id}")
return
if qq_msg.ctype == ContextType.TEXT:
cached_files = file_cache.get(session_id)
if cached_files:
file_refs = []
for fi in cached_files:
ftype = fi["type"]
fpath = fi["path"]
if ftype == "image":
file_refs.append(f"[图片: {fpath}]")
elif ftype == "video":
file_refs.append(f"[视频: {fpath}]")
else:
file_refs.append(f"[文件: {fpath}]")
qq_msg.content = qq_msg.content + "\n" + "\n".join(file_refs)
logger.info(f"[QQ] Attached {len(cached_files)} cached file(s)")
file_cache.clear(session_id)
context = self._compose_context(
qq_msg.ctype,
qq_msg.content,
isgroup=is_group,
msg=qq_msg,
no_need_at=True,
)
if context:
self.produce(context)
# ------------------------------------------------------------------
# _compose_context
# ------------------------------------------------------------------
def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content)
context.kwargs = kwargs
if "channel_type" not in context:
context["channel_type"] = self.channel_type
if "origin_ctype" not in context:
context["origin_ctype"] = ctype
cmsg = context["msg"]
if cmsg.is_group:
context["session_id"] = cmsg.other_user_id
else:
context["session_id"] = cmsg.from_user_id
context["receiver"] = cmsg.other_user_id
if ctype == ContextType.TEXT:
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
if img_match_prefix:
content = content.replace(img_match_prefix, "", 1)
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT
context.content = content.strip()
return context
# ------------------------------------------------------------------
# Send reply
# ------------------------------------------------------------------
def send(self, reply: Reply, context: Context):
msg = context.get("msg")
is_group = context.get("isgroup", False)
receiver = context.get("receiver", "")
if not msg:
# Active send (e.g. scheduled tasks), no original message to reply to
self._active_send_text(reply.content if reply.type == ReplyType.TEXT else str(reply.content),
receiver, is_group)
return
event_type = getattr(msg, "event_type", "")
msg_id = getattr(msg, "msg_id", "")
if reply.type == ReplyType.TEXT:
self._send_text(reply.content, msg, event_type, msg_id)
elif reply.type in (ReplyType.IMAGE_URL, ReplyType.IMAGE):
self._send_image(reply.content, msg, event_type, msg_id)
elif reply.type == ReplyType.FILE:
if hasattr(reply, "text_content") and reply.text_content:
self._send_text(reply.text_content, msg, event_type, msg_id)
time.sleep(0.3)
self._send_file(reply.content, msg, event_type, msg_id)
elif reply.type in (ReplyType.VIDEO, ReplyType.VIDEO_URL):
self._send_media(reply.content, msg, event_type, msg_id, QQ_FILE_TYPE_VIDEO)
else:
logger.warning(f"[QQ] Unsupported reply type: {reply.type}, falling back to text")
self._send_text(str(reply.content), msg, event_type, msg_id)
# ------------------------------------------------------------------
# Send helpers
# ------------------------------------------------------------------
def _get_next_msg_seq(self, msg_id: str) -> int:
seq = self._msg_seq_counter.get(msg_id, 1)
self._msg_seq_counter[msg_id] = seq + 1
return seq
def _build_msg_url_and_base_body(self, msg: QQMessage, event_type: str, msg_id: str):
"""Build the API URL and base body dict for sending a message."""
if event_type == "GROUP_AT_MESSAGE_CREATE":
group_openid = msg._rawmsg.get("group_openid", "")
url = f"{QQ_API_BASE}/v2/groups/{group_openid}/messages"
body = {
"msg_id": msg_id,
"msg_seq": self._get_next_msg_seq(msg_id),
}
return url, body, "group", group_openid
elif event_type == "C2C_MESSAGE_CREATE":
user_openid = msg._rawmsg.get("author", {}).get("user_openid", "") or msg.from_user_id
url = f"{QQ_API_BASE}/v2/users/{user_openid}/messages"
body = {
"msg_id": msg_id,
"msg_seq": self._get_next_msg_seq(msg_id),
}
return url, body, "c2c", user_openid
elif event_type == "AT_MESSAGE_CREATE":
channel_id = msg._rawmsg.get("channel_id", "")
url = f"{QQ_API_BASE}/channels/{channel_id}/messages"
body = {"msg_id": msg_id}
return url, body, "channel", channel_id
elif event_type == "DIRECT_MESSAGE_CREATE":
guild_id = msg._rawmsg.get("guild_id", "")
url = f"{QQ_API_BASE}/dms/{guild_id}/messages"
body = {"msg_id": msg_id}
return url, body, "dm", guild_id
return None, None, None, None
def _post_message(self, url: str, body: dict, event_type: str):
try:
resp = requests.post(url, json=body, headers=self._get_auth_headers(), timeout=10)
if resp.status_code in (200, 201, 202, 204):
logger.info(f"[QQ] Message sent successfully: event_type={event_type}")
else:
logger.error(f"[QQ] Failed to send message: status={resp.status_code}, "
f"body={resp.text}")
except Exception as e:
logger.error(f"[QQ] Send message error: {e}")
# ------------------------------------------------------------------
# Active send (no original message, e.g. scheduled tasks)
# ------------------------------------------------------------------
def _active_send_text(self, content: str, receiver: str, is_group: bool):
"""Send text without an original message (active push). QQ limits active messages to 4/month per user."""
if not receiver:
logger.warning("[QQ] No receiver for active send")
return
if is_group:
url = f"{QQ_API_BASE}/v2/groups/{receiver}/messages"
else:
url = f"{QQ_API_BASE}/v2/users/{receiver}/messages"
body = {
"content": content,
"msg_type": 0,
}
event_label = "GROUP_ACTIVE" if is_group else "C2C_ACTIVE"
self._post_message(url, body, event_label)
# ------------------------------------------------------------------
# Send text
# ------------------------------------------------------------------
def _send_text(self, content: str, msg: QQMessage, event_type: str, msg_id: str):
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
if not url:
logger.warning(f"[QQ] Cannot send reply for event_type: {event_type}")
return
body["content"] = content
body["msg_type"] = 0
self._post_message(url, body, event_type)
# ------------------------------------------------------------------
# Rich media upload & send (image / video / file)
# ------------------------------------------------------------------
def _upload_rich_media(self, file_url: str, file_type: int, msg: QQMessage,
event_type: str) -> str:
"""
Upload media via QQ rich media API and return file_info.
For group: POST /v2/groups/{group_openid}/files
For c2c: POST /v2/users/{openid}/files
"""
if event_type == "GROUP_AT_MESSAGE_CREATE":
group_openid = msg._rawmsg.get("group_openid", "")
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
elif event_type == "C2C_MESSAGE_CREATE":
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
or msg.from_user_id)
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
else:
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
return ""
upload_body = {
"file_type": file_type,
"url": file_url,
"srv_send_msg": False,
}
try:
resp = requests.post(
upload_url, json=upload_body,
headers=self._get_auth_headers(), timeout=30,
)
if resp.status_code in (200, 201):
data = resp.json()
file_info = data.get("file_info", "")
logger.info(f"[QQ] Rich media uploaded: file_type={file_type}, "
f"file_uuid={data.get('file_uuid', '')}")
return file_info
else:
logger.error(f"[QQ] Rich media upload failed: status={resp.status_code}, "
f"body={resp.text}")
return ""
except Exception as e:
logger.error(f"[QQ] Rich media upload error: {e}")
return ""
def _upload_rich_media_base64(self, file_path: str, file_type: int, msg: QQMessage,
event_type: str) -> str:
"""Upload local file via base64 file_data field."""
if event_type == "GROUP_AT_MESSAGE_CREATE":
group_openid = msg._rawmsg.get("group_openid", "")
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
elif event_type == "C2C_MESSAGE_CREATE":
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
or msg.from_user_id)
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
else:
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
return ""
try:
with open(file_path, "rb") as f:
file_data = base64.b64encode(f.read()).decode("utf-8")
except Exception as e:
logger.error(f"[QQ] Failed to read file for upload: {e}")
return ""
upload_body = {
"file_type": file_type,
"file_data": file_data,
"srv_send_msg": False,
}
try:
resp = requests.post(
upload_url, json=upload_body,
headers=self._get_auth_headers(), timeout=30,
)
if resp.status_code in (200, 201):
data = resp.json()
file_info = data.get("file_info", "")
logger.info(f"[QQ] Rich media uploaded (base64): file_type={file_type}, "
f"file_uuid={data.get('file_uuid', '')}")
return file_info
else:
logger.error(f"[QQ] Rich media upload (base64) failed: status={resp.status_code}, "
f"body={resp.text}")
return ""
except Exception as e:
logger.error(f"[QQ] Rich media upload (base64) error: {e}")
return ""
def _send_media_msg(self, file_info: str, msg: QQMessage, event_type: str, msg_id: str):
"""Send a message with msg_type=7 (rich media) using file_info."""
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
if not url:
return
body["msg_type"] = 7
body["media"] = {"file_info": file_info}
self._post_message(url, body, event_type)
def _send_image(self, img_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
"""Send image reply. Supports URL and local file path."""
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
self._send_text(str(img_path_or_url), msg, event_type, msg_id)
return
if img_path_or_url.startswith("file://"):
img_path_or_url = img_path_or_url[7:]
if img_path_or_url.startswith(("http://", "https://")):
file_info = self._upload_rich_media(
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
elif os.path.exists(img_path_or_url):
file_info = self._upload_rich_media_base64(
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
else:
logger.error(f"[QQ] Image not found: {img_path_or_url}")
self._send_text("[Image send failed]", msg, event_type, msg_id)
return
if file_info:
self._send_media_msg(file_info, msg, event_type, msg_id)
else:
self._send_text("[Image upload failed]", msg, event_type, msg_id)
def _send_file(self, file_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
"""Send file reply."""
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
self._send_text(str(file_path_or_url), msg, event_type, msg_id)
return
if file_path_or_url.startswith("file://"):
file_path_or_url = file_path_or_url[7:]
if file_path_or_url.startswith(("http://", "https://")):
file_info = self._upload_rich_media(
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
elif os.path.exists(file_path_or_url):
file_info = self._upload_rich_media_base64(
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
else:
logger.error(f"[QQ] File not found: {file_path_or_url}")
self._send_text("[File send failed]", msg, event_type, msg_id)
return
if file_info:
self._send_media_msg(file_info, msg, event_type, msg_id)
else:
self._send_text("[File upload failed]", msg, event_type, msg_id)
def _send_media(self, path_or_url: str, msg: QQMessage, event_type: str,
msg_id: str, file_type: int):
"""Generic media send for video/voice etc."""
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
self._send_text(str(path_or_url), msg, event_type, msg_id)
return
if path_or_url.startswith("file://"):
path_or_url = path_or_url[7:]
if path_or_url.startswith(("http://", "https://")):
file_info = self._upload_rich_media(path_or_url, file_type, msg, event_type)
elif os.path.exists(path_or_url):
file_info = self._upload_rich_media_base64(path_or_url, file_type, msg, event_type)
else:
logger.error(f"[QQ] Media not found: {path_or_url}")
return
if file_info:
self._send_media_msg(file_info, msg, event_type, msg_id)
else:
logger.error(f"[QQ] Media upload failed: {path_or_url}")

123
channel/qq/qq_message.py Normal file
View File

@@ -0,0 +1,123 @@
import os
import requests
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.log import logger
from common.utils import expand_path
from config import conf
def _get_tmp_dir() -> str:
"""Return the workspace tmp directory (absolute path), creating it if needed."""
ws_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(ws_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
return tmp_dir
class QQMessage(ChatMessage):
"""Message wrapper for QQ Bot (websocket long-connection mode)."""
def __init__(self, event_data: dict, event_type: str):
super().__init__(event_data)
self.msg_id = event_data.get("id", "")
self.create_time = event_data.get("timestamp", "")
self.is_group = event_type in ("GROUP_AT_MESSAGE_CREATE",)
self.event_type = event_type
author = event_data.get("author", {})
from_user_id = author.get("member_openid", "") or author.get("id", "")
group_openid = event_data.get("group_openid", "")
content = event_data.get("content", "").strip()
attachments = event_data.get("attachments", [])
has_image = any(
a.get("content_type", "").startswith("image/") for a in attachments
) if attachments else False
if has_image and not content:
self.ctype = ContextType.IMAGE
img_attachment = next(
a for a in attachments if a.get("content_type", "").startswith("image/")
)
img_url = img_attachment.get("url", "")
if img_url and not img_url.startswith("http"):
img_url = "https://" + img_url
tmp_dir = _get_tmp_dir()
image_path = os.path.join(tmp_dir, f"qq_{self.msg_id}.png")
try:
resp = requests.get(img_url, timeout=30)
resp.raise_for_status()
with open(image_path, "wb") as f:
f.write(resp.content)
self.content = image_path
self.image_path = image_path
logger.info(f"[QQ] Image downloaded: {image_path}")
except Exception as e:
logger.error(f"[QQ] Failed to download image: {e}")
self.content = "[Image download failed]"
self.image_path = None
elif has_image and content:
self.ctype = ContextType.TEXT
image_paths = []
tmp_dir = _get_tmp_dir()
for idx, att in enumerate(attachments):
if not att.get("content_type", "").startswith("image/"):
continue
img_url = att.get("url", "")
if img_url and not img_url.startswith("http"):
img_url = "https://" + img_url
img_path = os.path.join(tmp_dir, f"qq_{self.msg_id}_{idx}.png")
try:
resp = requests.get(img_url, timeout=30)
resp.raise_for_status()
with open(img_path, "wb") as f:
f.write(resp.content)
image_paths.append(img_path)
except Exception as e:
logger.error(f"[QQ] Failed to download mixed image: {e}")
content_parts = [content]
for p in image_paths:
content_parts.append(f"[图片: {p}]")
self.content = "\n".join(content_parts)
else:
self.ctype = ContextType.TEXT
self.content = content
if event_type == "GROUP_AT_MESSAGE_CREATE":
self.from_user_id = from_user_id
self.to_user_id = ""
self.other_user_id = group_openid
self.actual_user_id = from_user_id
self.actual_user_nickname = from_user_id
elif event_type == "C2C_MESSAGE_CREATE":
user_openid = author.get("user_openid", "") or from_user_id
self.from_user_id = user_openid
self.to_user_id = ""
self.other_user_id = user_openid
self.actual_user_id = user_openid
elif event_type == "AT_MESSAGE_CREATE":
self.from_user_id = from_user_id
self.to_user_id = ""
channel_id = event_data.get("channel_id", "")
self.other_user_id = channel_id
self.actual_user_id = from_user_id
self.actual_user_nickname = author.get("username", from_user_id)
elif event_type == "DIRECT_MESSAGE_CREATE":
self.from_user_id = from_user_id
self.to_user_id = ""
guild_id = event_data.get("guild_id", "")
self.other_user_id = f"dm_{guild_id}_{from_user_id}"
self.actual_user_id = from_user_id
self.actual_user_nickname = author.get("username", from_user_id)
else:
raise NotImplementedError(f"Unsupported QQ event type: {event_type}")
logger.debug(f"[QQ] Message parsed: type={event_type}, ctype={self.ctype}, "
f"from={self.from_user_id}, content_len={len(self.content)}")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,537 @@
/* =====================================================================
CowAgent Console Styles
===================================================================== */
/* Animations */
@keyframes pulseDot {
0%, 80%, 100% { transform: scale(0.6); opacity: 0.4; }
40% { transform: scale(1); opacity: 1; }
}
/* Scrollbar */
* { scrollbar-width: thin; scrollbar-color: #94a3b8 transparent; }
::-webkit-scrollbar { width: 6px; height: 6px; }
::-webkit-scrollbar-track { background: transparent; }
::-webkit-scrollbar-thumb { background: #94a3b8; border-radius: 3px; }
::-webkit-scrollbar-thumb:hover { background: #64748b; }
.dark ::-webkit-scrollbar-thumb { background: #475569; }
.dark ::-webkit-scrollbar-thumb:hover { background: #64748b; }
/* Sidebar */
.sidebar-item.active {
background: rgba(255, 255, 255, 0.08);
color: #FFFFFF;
}
.sidebar-item.active .item-icon { color: #4ABE6E; }
/* Menu Groups */
.menu-group-items { max-height: 0; overflow: hidden; transition: max-height 0.25s ease-out; }
.menu-group.open .menu-group-items { max-height: 500px; transition: max-height 0.35s ease-in; }
.menu-group .chevron { transition: transform 0.25s ease; }
.menu-group.open .chevron { transform: rotate(90deg); }
/* View Switching */
.view { display: none; height: 100%; }
.view.active { display: flex; flex-direction: column; }
/* Markdown Content */
.msg-content p { margin: 0.5em 0; line-height: 1.7; }
.msg-content p:first-child { margin-top: 0; }
.msg-content p:last-child { margin-bottom: 0; }
.msg-content h1, .msg-content h2, .msg-content h3,
.msg-content h4, .msg-content h5, .msg-content h6 {
margin-top: 1.2em; margin-bottom: 0.6em; font-weight: 600; line-height: 1.3;
}
.msg-content h1 { font-size: 1.4em; }
.msg-content h2 { font-size: 1.25em; }
.msg-content h3 { font-size: 1.1em; }
.msg-content ul, .msg-content ol { margin: 0.5em 0; padding-left: 1.8em; }
.msg-content li { margin: 0.25em 0; }
.msg-content pre {
border-radius: 8px; overflow-x: auto; margin: 0.8em 0;
background: #f1f5f9; padding: 1em;
}
.dark .msg-content pre { background: #111111; }
.msg-content code {
font-family: 'JetBrains Mono', 'Fira Code', Consolas, monospace;
font-size: 0.875em;
}
.msg-content :not(pre) > code {
background: rgba(74, 190, 110, 0.1); color: #1C6B3B;
padding: 2px 6px; border-radius: 4px;
}
.dark .msg-content :not(pre) > code {
background: rgba(74, 190, 110, 0.15); color: #74E9A4;
}
.msg-content pre code { background: transparent; padding: 0; color: inherit; }
.msg-content blockquote {
border-left: 3px solid #4ABE6E; padding: 0.5em 1em;
margin: 0.8em 0; background: rgba(74, 190, 110, 0.05); border-radius: 0 6px 6px 0;
}
.dark .msg-content blockquote { background: rgba(74, 190, 110, 0.08); }
.msg-content table { border-collapse: collapse; width: 100%; margin: 0.8em 0; }
.msg-content th, .msg-content td {
border: 1px solid #e2e8f0; padding: 8px 12px; text-align: left;
}
.dark .msg-content th, .dark .msg-content td { border-color: rgba(255,255,255,0.1); }
.msg-content th { background: #f1f5f9; font-weight: 600; }
.dark .msg-content th { background: #111111; }
.msg-content img { max-width: 100%; height: auto; border-radius: 8px; margin: 0.5em 0; }
.msg-content a { color: #35A85B; text-decoration: underline; }
.msg-content a:hover { color: #228547; }
/* Overrides for user bubble (white text on green bg) */
.user-bubble.msg-content a { color: #ffffff !important; text-decoration: underline; text-decoration-color: rgba(255,255,255,0.6); }
.user-bubble.msg-content a:hover { color: #e0f5e8 !important; text-decoration-color: #e0f5e8; }
.user-bubble.msg-content :not(pre) > code { background: rgba(255,255,255,0.2); color: #ffffff; }
.msg-content hr { border: none; height: 1px; background: #e2e8f0; margin: 1.2em 0; }
.dark .msg-content hr { background: rgba(255,255,255,0.1); }
/* SSE Streaming cursor */
@keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
.sse-streaming::after {
content: '▋';
display: inline-block;
margin-left: 2px;
color: #4ABE6E;
animation: blink 0.9s step-end infinite;
font-size: 0.85em;
vertical-align: middle;
}
/* Agent steps (thinking summaries + tool indicators) */
.agent-steps:empty { display: none; }
.agent-steps:not(:empty) {
margin-bottom: 0.625rem;
padding-bottom: 0.5rem;
border-bottom: 1px dashed rgba(0, 0, 0, 0.08);
}
.dark .agent-steps:not(:empty) { border-bottom-color: rgba(255, 255, 255, 0.08); }
.agent-step {
font-size: 0.75rem;
line-height: 1.4;
color: #94a3b8;
margin-bottom: 0.25rem;
}
.agent-step:last-child { margin-bottom: 0; }
/* Thinking step - collapsible */
.agent-thinking-step .thinking-header {
display: flex;
align-items: center;
gap: 0.375rem;
cursor: pointer;
user-select: none;
}
.agent-thinking-step .thinking-header.no-toggle { cursor: default; }
.agent-thinking-step .thinking-header:not(.no-toggle):hover { color: #64748b; }
.dark .agent-thinking-step .thinking-header:not(.no-toggle):hover { color: #cbd5e1; }
.agent-thinking-step .thinking-header i:first-child { font-size: 0.625rem; margin-top: 1px; }
.agent-thinking-step .thinking-chevron {
font-size: 0.5rem;
margin-left: auto;
transition: transform 0.2s ease;
opacity: 0.5;
}
.agent-thinking-step.expanded .thinking-chevron { transform: rotate(90deg); }
.agent-thinking-step .thinking-full {
display: none;
margin-top: 0.375rem;
margin-left: 1rem;
padding: 0.5rem;
background: rgba(0, 0, 0, 0.02);
border-radius: 6px;
border: 1px solid rgba(0, 0, 0, 0.04);
font-size: 0.75rem;
line-height: 1.5;
color: #94a3b8;
max-height: 200px;
overflow-y: auto;
}
.dark .agent-thinking-step .thinking-full {
background: rgba(255, 255, 255, 0.02);
border-color: rgba(255, 255, 255, 0.04);
}
.agent-thinking-step.expanded .thinking-full { display: block; }
.agent-thinking-step .thinking-full p { margin: 0.25em 0; }
.agent-thinking-step .thinking-full p:first-child { margin-top: 0; }
.agent-thinking-step .thinking-full p:last-child { margin-bottom: 0; }
/* Tool step - collapsible */
.agent-tool-step .tool-header {
display: flex;
align-items: center;
gap: 0.375rem;
cursor: pointer;
user-select: none;
padding: 1px 0;
border-radius: 4px;
}
.agent-tool-step .tool-header:hover { color: #64748b; }
.dark .agent-tool-step .tool-header:hover { color: #cbd5e1; }
.agent-tool-step .tool-icon { font-size: 0.625rem; }
.agent-tool-step .tool-chevron {
font-size: 0.5rem;
margin-left: auto;
transition: transform 0.2s ease;
opacity: 0.5;
}
.agent-tool-step.expanded .tool-chevron { transform: rotate(90deg); }
.agent-tool-step .tool-time {
font-size: 0.65rem;
opacity: 0.6;
margin-left: 0.25rem;
}
/* Tool detail panel */
.agent-tool-step .tool-detail {
display: none;
margin-top: 0.375rem;
margin-left: 1rem;
padding: 0.5rem;
background: rgba(0, 0, 0, 0.02);
border-radius: 6px;
border: 1px solid rgba(0, 0, 0, 0.04);
}
.dark .agent-tool-step .tool-detail {
background: rgba(255, 255, 255, 0.02);
border-color: rgba(255, 255, 255, 0.04);
}
.agent-tool-step.expanded .tool-detail { display: block; }
.tool-detail-section { margin-bottom: 0.375rem; }
.tool-detail-section:last-child { margin-bottom: 0; }
.tool-detail-label {
font-size: 0.625rem;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.05em;
opacity: 0.6;
margin-bottom: 0.125rem;
}
.tool-detail-content {
font-family: 'JetBrains Mono', 'Fira Code', Consolas, monospace;
font-size: 0.7rem;
line-height: 1.5;
white-space: pre-wrap;
word-break: break-all;
max-height: 200px;
overflow-y: auto;
margin: 0;
padding: 0.25rem 0;
background: transparent;
color: inherit;
}
.tool-error-text { color: #f87171; }
/* Tool failed state */
.agent-tool-step.tool-failed .tool-name { color: #f87171; }
/* Config form controls */
#view-config input[type="text"],
#view-config input[type="number"],
#view-config input[type="password"] {
height: 40px;
transition: border-color 0.2s ease, box-shadow 0.2s ease;
}
#view-config input:focus {
border-color: #4ABE6E;
box-shadow: 0 0 0 3px rgba(74, 190, 110, 0.12);
}
#view-config input[type="text"]:hover,
#view-config input[type="number"]:hover,
#view-config input[type="password"]:hover {
border-color: #94a3b8;
}
.dark #view-config input[type="text"]:hover,
.dark #view-config input[type="number"]:hover,
.dark #view-config input[type="password"]:hover {
border-color: #64748b;
}
/* Custom dropdown */
.cfg-dropdown {
position: relative;
outline: none;
}
.cfg-dropdown-selected {
display: flex;
align-items: center;
justify-content: space-between;
height: 40px;
padding: 0 0.75rem;
border-radius: 0.5rem;
border: 1px solid #e2e8f0;
background: #f8fafc;
font-size: 0.875rem;
color: #1e293b;
cursor: pointer;
transition: border-color 0.2s ease, box-shadow 0.2s ease;
user-select: none;
}
.dark .cfg-dropdown-selected {
border-color: #475569;
background: rgba(255, 255, 255, 0.05);
color: #f1f5f9;
}
.cfg-dropdown-selected:hover { border-color: #94a3b8; }
.dark .cfg-dropdown-selected:hover { border-color: #64748b; }
.cfg-dropdown.open .cfg-dropdown-selected,
.cfg-dropdown:focus .cfg-dropdown-selected {
border-color: #4ABE6E;
box-shadow: 0 0 0 3px rgba(74, 190, 110, 0.12);
}
.cfg-dropdown-arrow {
font-size: 0.625rem;
color: #94a3b8;
transition: transform 0.2s ease;
flex-shrink: 0;
margin-left: 0.5rem;
}
.cfg-dropdown.open .cfg-dropdown-arrow { transform: rotate(180deg); }
.cfg-dropdown-menu {
display: none;
position: absolute;
top: calc(100% + 4px);
left: 0;
right: 0;
z-index: 50;
max-height: 240px;
overflow-y: auto;
border-radius: 0.5rem;
border: 1px solid #e2e8f0;
background: #ffffff;
box-shadow: 0 10px 25px -5px rgba(0, 0, 0, 0.1), 0 4px 10px -5px rgba(0, 0, 0, 0.04);
padding: 4px;
}
.dark .cfg-dropdown-menu {
border-color: #334155;
background: #1e1e1e;
box-shadow: 0 10px 25px -5px rgba(0, 0, 0, 0.4);
}
.cfg-dropdown.open .cfg-dropdown-menu { display: block; }
.cfg-dropdown-item {
display: flex;
align-items: center;
padding: 8px 10px;
border-radius: 6px;
font-size: 0.875rem;
color: #334155;
cursor: pointer;
transition: background 0.15s ease;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.dark .cfg-dropdown-item { color: #cbd5e1; }
.cfg-dropdown-item:hover { background: #f1f5f9; }
.dark .cfg-dropdown-item:hover { background: rgba(255, 255, 255, 0.08); }
.cfg-dropdown-item.active {
background: rgba(74, 190, 110, 0.1);
color: #228547;
font-weight: 500;
}
.dark .cfg-dropdown-item.active {
background: rgba(74, 190, 110, 0.15);
color: #74E9A4;
}
/* API Key masking via CSS (avoids browser password prompts) */
.cfg-key-masked {
-webkit-text-security: disc;
text-security: disc;
}
/* Chat Input */
#chat-input {
resize: none; height: 42px; max-height: 180px;
overflow-y: hidden;
transition: border-color 0.2s ease;
}
/* Attachment Preview Bar */
.attachment-preview {
display: flex;
flex-wrap: wrap;
gap: 8px;
padding: 8px 0;
}
.attachment-preview.hidden { display: none; }
.att-thumb {
position: relative;
width: 64px; height: 64px;
border-radius: 8px;
overflow: hidden;
border: 1px solid #e2e8f0;
flex-shrink: 0;
}
.dark .att-thumb { border-color: rgba(255,255,255,0.1); }
.att-thumb img {
width: 100%; height: 100%;
object-fit: cover;
}
.att-chip {
position: relative;
display: flex;
align-items: center;
gap: 6px;
padding: 6px 28px 6px 10px;
border-radius: 8px;
background: #f1f5f9;
border: 1px solid #e2e8f0;
font-size: 12px;
color: #475569;
max-width: 180px;
}
.dark .att-chip { background: rgba(255,255,255,0.05); border-color: rgba(255,255,255,0.1); color: #94a3b8; }
.att-uploading { opacity: 0.6; pointer-events: none; }
.att-name {
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.att-remove {
position: absolute;
top: -4px; right: -4px;
width: 18px; height: 18px;
border-radius: 50%;
background: #ef4444;
color: #fff;
border: none;
font-size: 12px;
line-height: 18px;
text-align: center;
cursor: pointer;
padding: 0;
opacity: 0;
transition: opacity 0.15s;
}
.att-thumb:hover .att-remove,
.att-chip:hover .att-remove { opacity: 1; }
/* Drag-over highlight */
.drag-over {
background: rgba(74, 190, 110, 0.08) !important;
border-color: #4ABE6E !important;
}
/* User message attachments */
.user-msg-attachments {
display: flex;
flex-wrap: wrap;
gap: 6px;
margin-bottom: 6px;
}
.user-msg-image {
max-width: 200px;
max-height: 160px;
border-radius: 8px;
object-fit: cover;
cursor: pointer;
}
.user-msg-image:hover { opacity: 0.9; }
.user-msg-file {
display: flex;
align-items: center;
gap: 6px;
padding: 4px 10px;
border-radius: 6px;
background: rgba(255,255,255,0.15);
font-size: 12px;
}
/* Placeholder Cards */
.placeholder-card {
transition: transform 0.2s ease, box-shadow 0.2s ease;
}
.placeholder-card:hover {
transform: translateY(-2px);
box-shadow: 0 8px 25px -5px rgba(0, 0, 0, 0.1);
}
/* Slash Command Menu */
.slash-menu {
position: absolute;
bottom: calc(100% + 6px);
left: 0;
right: 0;
max-height: 320px;
overflow-y: auto;
background: #fff;
border: 1px solid #e2e8f0;
border-radius: 12px;
box-shadow: 0 8px 30px -6px rgba(0, 0, 0, 0.1), 0 2px 8px -2px rgba(0, 0, 0, 0.04);
z-index: 50;
padding: 4px;
animation: slashMenuIn 0.15s ease-out;
}
.slash-menu.hidden { display: none; }
@keyframes slashMenuIn {
from { opacity: 0; transform: translateY(6px); }
to { opacity: 1; transform: translateY(0); }
}
.slash-menu-header {
padding: 6px 10px 4px;
font-size: 11px;
font-weight: 600;
color: #94a3b8;
text-transform: uppercase;
letter-spacing: 0.05em;
}
.slash-menu-item {
display: flex;
align-items: center;
justify-content: space-between;
padding: 8px 10px;
border-radius: 8px;
cursor: pointer;
transition: background 0.12s ease;
}
.slash-menu-item:hover,
.slash-menu-item.active {
background: #EDFDF3;
}
.slash-menu-item .cmd {
font-size: 13px;
font-weight: 500;
color: #334155;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Menlo, monospace;
}
.slash-menu-item.active .cmd {
color: #228547;
}
.slash-menu-item .desc {
font-size: 12px;
color: #94a3b8;
margin-left: 12px;
white-space: nowrap;
}
/* Dark mode */
.dark .slash-menu {
background: #1A1A1A;
border-color: rgba(255, 255, 255, 0.1);
box-shadow: 0 8px 30px -6px rgba(0, 0, 0, 0.35), 0 2px 8px -2px rgba(0, 0, 0, 0.15);
}
.dark .slash-menu-header {
color: #64748b;
}
.dark .slash-menu-item:hover,
.dark .slash-menu-item.active {
background: rgba(74, 190, 110, 0.1);
}
.dark .slash-menu-item .cmd {
color: #e2e8f0;
}
.dark .slash-menu-item.active .cmd {
color: #4ABE6E;
}
.dark .slash-menu-item .desc {
color: #64748b;
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,179 +0,0 @@
# encoding:utf-8
"""
wechat channel
"""
import io
import json
import os
import threading
import time
from queue import Empty
from typing import Any
from bridge.context import *
from bridge.reply import *
from channel.chat_channel import ChatChannel
from channel.wechat.wcf_message import WechatfMessage
from common.log import logger
from common.singleton import singleton
from common.utils import *
from config import conf, get_appdata_dir
from wcferry import Wcf, WxMsg
@singleton
class WechatfChannel(ChatChannel):
NOT_SUPPORT_REPLYTYPE = []
def __init__(self):
super().__init__()
self.NOT_SUPPORT_REPLYTYPE = []
# 使用字典存储最近消息,用于去重
self.received_msgs = {}
# 初始化wcferry客户端
self.wcf = Wcf()
self.wxid = None # 登录后会被设置为当前登录用户的wxid
def startup(self):
"""
启动通道
"""
try:
# wcferry会自动唤起微信并登录
self.wxid = self.wcf.get_self_wxid()
self.name = self.wcf.get_user_info().get("name")
logger.info(f"微信登录成功当前用户ID: {self.wxid}, 用户名:{self.name}")
self.contact_cache = ContactCache(self.wcf)
self.contact_cache.update()
# 启动消息接收
self.wcf.enable_receiving_msg()
# 创建消息处理线程
t = threading.Thread(target=self._process_messages, name="WeChatThread", daemon=True)
t.start()
except Exception as e:
logger.error(f"微信通道启动失败: {e}")
raise e
def _process_messages(self):
"""
处理消息队列
"""
while True:
try:
msg = self.wcf.get_msg()
if msg:
self._handle_message(msg)
except Empty:
continue
except Exception as e:
logger.error(f"处理消息失败: {e}")
continue
def _handle_message(self, msg: WxMsg):
"""
处理单条消息
"""
try:
# 构造消息对象
cmsg = WechatfMessage(self, msg)
# 消息去重
if cmsg.msg_id in self.received_msgs:
return
self.received_msgs[cmsg.msg_id] = time.time()
# 清理过期消息ID
self._clean_expired_msgs()
logger.debug(f"收到消息: {msg}")
context = self._compose_context(cmsg.ctype, cmsg.content,
isgroup=cmsg.is_group,
msg=cmsg)
if context:
self.produce(context)
except Exception as e:
logger.error(f"处理消息失败: {e}")
def _clean_expired_msgs(self, expire_time: float = 60):
"""
清理过期的消息ID
"""
now = time.time()
for msg_id in list(self.received_msgs.keys()):
if now - self.received_msgs[msg_id] > expire_time:
del self.received_msgs[msg_id]
def send(self, reply: Reply, context: Context):
"""
发送消息
"""
receiver = context["receiver"]
if not receiver:
logger.error("receiver is empty")
return
try:
if reply.type == ReplyType.TEXT:
# 处理@信息
at_list = []
if context.get("isgroup"):
if context["msg"].actual_user_id:
at_list = [context["msg"].actual_user_id]
at_str = ",".join(at_list) if at_list else ""
self.wcf.send_text(reply.content, receiver, at_str)
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
self.wcf.send_text(reply.content, receiver)
else:
logger.error(f"暂不支持的消息类型: {reply.type}")
except Exception as e:
logger.error(f"发送消息失败: {e}")
def close(self):
"""
关闭通道
"""
try:
self.wcf.cleanup()
except Exception as e:
logger.error(f"关闭通道失败: {e}")
class ContactCache:
def __init__(self, wcf):
"""
wcf: 一个 wcfferry.client.Wcf 实例
"""
self.wcf = wcf
self._contact_map = {} # 形如 {wxid: {完整联系人信息}}
def update(self):
"""
更新缓存:调用 get_contacts()
再把 wcf.contacts 构建成 {wxid: {完整信息}} 的字典
"""
self.wcf.get_contacts()
self._contact_map.clear()
for item in self.wcf.contacts:
wxid = item.get('wxid')
if wxid: # 确保有 wxid 字段
self._contact_map[wxid] = item
def get_contact(self, wxid: str) -> dict:
"""
返回该 wxid 对应的完整联系人 dict
如果没找到就返回 None
"""
return self._contact_map.get(wxid)
def get_name_by_wxid(self, wxid: str) -> str:
"""
通过wxid获取成员/群名称
"""
contact = self.get_contact(wxid)
if contact:
return contact.get('name', '')
return ''

View File

@@ -1,58 +0,0 @@
# encoding:utf-8
"""
wechat channel message
"""
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.log import logger
from wcferry import WxMsg
class WechatfMessage(ChatMessage):
"""
微信消息封装类
"""
def __init__(self, channel, wcf_msg: WxMsg, is_group=False):
"""
初始化消息对象
:param wcf_msg: wcferry消息对象
:param is_group: 是否是群消息
"""
super().__init__(wcf_msg)
self.msg_id = wcf_msg.id
self.create_time = wcf_msg.ts # 使用消息时间戳
self.is_group = is_group or wcf_msg._is_group
self.wxid = channel.wxid
self.name = channel.name
# 解析消息类型
if wcf_msg.is_text():
self.ctype = ContextType.TEXT
self.content = wcf_msg.content
else:
raise NotImplementedError(f"Unsupported message type: {wcf_msg.type}")
# 设置发送者和接收者信息
self.from_user_id = self.wxid if wcf_msg.sender == self.wxid else wcf_msg.sender
self.from_user_nickname = self.name if wcf_msg.sender == self.wxid else channel.contact_cache.get_name_by_wxid(wcf_msg.sender)
self.to_user_id = self.wxid
self.to_user_nickname = self.name
self.other_user_id = wcf_msg.sender
self.other_user_nickname = channel.contact_cache.get_name_by_wxid(wcf_msg.sender)
# 群消息特殊处理
if self.is_group:
self.other_user_id = wcf_msg.roomid
self.other_user_nickname = channel.contact_cache.get_name_by_wxid(wcf_msg.roomid)
self.actual_user_id = wcf_msg.sender
self.actual_user_nickname = channel.wcf.get_alias_in_chatroom(wcf_msg.sender, wcf_msg.roomid)
if not self.actual_user_nickname: # 群聊获取不到企微号成员昵称,这里尝试从联系人缓存去获取
self.actual_user_nickname = channel.contact_cache.get_name_by_wxid(wcf_msg.sender)
self.room_id = wcf_msg.roomid
self.is_at = wcf_msg.is_at(self.wxid) # 是否被@当前登录用户
# 判断是否是自己发送的消息
self.my_msg = wcf_msg.from_self()

View File

@@ -1,309 +0,0 @@
# encoding:utf-8
"""
wechat channel
"""
import io
import json
import os
import threading
import time
import requests
from bridge.context import *
from bridge.reply import *
from channel.chat_channel import ChatChannel
from channel import chat_channel
from channel.wechat.wechat_message import *
from common.expired_dict import ExpiredDict
from common.log import logger
from common.singleton import singleton
from common.time_check import time_checker
from common.utils import convert_webp_to_png, remove_markdown_symbol
from config import conf, get_appdata_dir
from lib import itchat
from lib.itchat.content import *
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING])
def handler_single_msg(msg):
try:
cmsg = WechatMessage(msg, False)
except NotImplementedError as e:
logger.debug("[WX]single message {} skipped: {}".format(msg["MsgId"], e))
return None
WechatChannel().handle_single(cmsg)
return None
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING], isGroupChat=True)
def handler_group_msg(msg):
try:
cmsg = WechatMessage(msg, True)
except NotImplementedError as e:
logger.debug("[WX]group message {} skipped: {}".format(msg["MsgId"], e))
return None
WechatChannel().handle_group(cmsg)
return None
def _check(func):
def wrapper(self, cmsg: ChatMessage):
msgId = cmsg.msg_id
if msgId in self.receivedMsgs:
logger.info("Wechat message {} already received, ignore".format(msgId))
return
self.receivedMsgs[msgId] = True
create_time = cmsg.create_time # 消息时间戳
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
logger.debug("[WX]history message {} skipped".format(msgId))
return
if cmsg.my_msg and not cmsg.is_group:
logger.debug("[WX]my message {} skipped".format(msgId))
return
return func(self, cmsg)
return wrapper
# 可用的二维码生成接口
# https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
# https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
def qrCallback(uuid, status, qrcode):
# logger.debug("qrCallback: {} {}".format(uuid,status))
if status == "0":
try:
from PIL import Image
img = Image.open(io.BytesIO(qrcode))
_thread = threading.Thread(target=img.show, args=("QRCode",))
_thread.setDaemon(True)
_thread.start()
except Exception as e:
pass
import qrcode
url = f"https://login.weixin.qq.com/l/{uuid}"
qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
print("You can also scan QRCode in any website below:")
print(qr_api3)
print(qr_api4)
print(qr_api2)
print(qr_api1)
_send_qr_code([qr_api3, qr_api4, qr_api2, qr_api1])
qr = qrcode.QRCode(border=1)
qr.add_data(url)
qr.make(fit=True)
try:
qr.print_ascii(invert=True)
except UnicodeEncodeError:
print("ASCII QR code printing failed due to encoding issues.")
@singleton
class WechatChannel(ChatChannel):
NOT_SUPPORT_REPLYTYPE = []
def __init__(self):
super().__init__()
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
self.auto_login_times = 0
def startup(self):
try:
time.sleep(3)
logger.error("""[WechatChannel] 当前channel暂不可用目前支持的channel有:
1. terminal: 终端
2. wechatmp: 个人公众号
3. wechatmp_service: 企业公众号
4. wechatcom_app: 企微自建应用
5. dingtalk: 钉钉
6. feishu: 飞书
7. web: 网页
8. wcf: wechat (需Windows环境参考 https://github.com/zhayujie/chatgpt-on-wechat/pull/2562 )
可修改 config.json 配置文件的 channel_type 字段进行切换""")
# itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
# # login by scan QRCode
# hotReload = conf().get("hot_reload", False)
# status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
# itchat.auto_login(
# enableCmdQR=2,
# hotReload=hotReload,
# statusStorageDir=status_path,
# qrCallback=qrCallback,
# exitCallback=self.exitCallback,
# loginCallback=self.loginCallback
# )
# self.user_id = itchat.instance.storageClass.userName
# self.name = itchat.instance.storageClass.nickName
# logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
# # start message listener
# itchat.run()
except Exception as e:
logger.exception(e)
def exitCallback(self):
try:
from common.linkai_client import chat_client
if chat_client.client_id and conf().get("use_linkai"):
_send_logout()
time.sleep(2)
self.auto_login_times += 1
if self.auto_login_times < 100:
chat_channel.handler_pool._shutdown = False
self.startup()
except Exception as e:
pass
def loginCallback(self):
logger.debug("Login success")
_send_login_success()
# handle_* 系列函数处理收到的消息后构造Context然后传入produce函数中处理Context和发送回复
# Context包含了消息的所有信息包括以下属性
# type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE
# content 消息内容如果是TEXT类型content就是文本内容如果是VOICE类型content就是语音文件名如果是IMAGE_CREATE类型content就是图片生成命令
# kwargs 附加参数字典包含以下的key
# session_id: 会话id
# isgroup: 是否是群聊
# receiver: 需要回复的对象
# msg: ChatMessage消息对象
# origin_ctype: 原始消息类型,语音转文字后,私聊时如果匹配前缀失败,会根据初始消息是否是语音来放宽触发规则
# desire_rtype: 希望回复类型默认是文本回复设置为ReplyType.VOICE是语音回复
@time_checker
@_check
def handle_single(self, cmsg: ChatMessage):
# filter system message
if cmsg.other_user_id in ["weixin"]:
return
if cmsg.ctype == ContextType.VOICE:
if conf().get("speech_recognition") != True:
return
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.IMAGE:
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.PATPAT:
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.TEXT:
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
else:
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
if context:
self.produce(context)
@time_checker
@_check
def handle_group(self, cmsg: ChatMessage):
if cmsg.ctype == ContextType.VOICE:
if conf().get("group_speech_recognition") != True:
return
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.IMAGE:
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.ACCEPT_FRIEND, ContextType.EXIT_GROUP]:
logger.debug("[WX]receive note msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.TEXT:
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
pass
elif cmsg.ctype == ContextType.FILE:
logger.debug(f"[WX]receive attachment msg, file_name={cmsg.content}")
else:
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg, no_need_at=conf().get("no_need_at", False))
if context:
self.produce(context)
# 统一的发送函数每个Channel自行实现根据reply的type字段发送不同类型的消息
def send(self, reply: Reply, context: Context):
receiver = context["receiver"]
if reply.type == ReplyType.TEXT:
reply.content = remove_markdown_symbol(reply.content)
itchat.send(reply.content, toUserName=receiver)
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
reply.content = remove_markdown_symbol(reply.content)
itchat.send(reply.content, toUserName=receiver)
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.VOICE:
itchat.send_file(reply.content, toUserName=receiver)
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
logger.debug(f"[WX] start download image, img_url={img_url}")
pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO()
size = 0
for block in pic_res.iter_content(1024):
size += len(block)
image_storage.write(block)
logger.info(f"[WX] download image success, size={size}, img_url={img_url}")
image_storage.seek(0)
if ".webp" in img_url:
try:
image_storage = convert_webp_to_png(image_storage)
except Exception as e:
logger.error(f"Failed to convert image: {e}")
return
itchat.send_image(image_storage, toUserName=receiver)
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver)
logger.info("[WX] sendImage, receiver={}".format(receiver))
elif reply.type == ReplyType.FILE: # 新增文件回复类型
file_storage = reply.content
itchat.send_file(file_storage, toUserName=receiver)
logger.info("[WX] sendFile, receiver={}".format(receiver))
elif reply.type == ReplyType.VIDEO: # 新增视频回复类型
video_storage = reply.content
itchat.send_video(video_storage, toUserName=receiver)
logger.info("[WX] sendFile, receiver={}".format(receiver))
elif reply.type == ReplyType.VIDEO_URL: # 新增视频URL回复类型
video_url = reply.content
logger.debug(f"[WX] start download video, video_url={video_url}")
video_res = requests.get(video_url, stream=True)
video_storage = io.BytesIO()
size = 0
for block in video_res.iter_content(1024):
size += len(block)
video_storage.write(block)
logger.info(f"[WX] download video success, size={size}, video_url={video_url}")
video_storage.seek(0)
itchat.send_video(video_storage, toUserName=receiver)
logger.info("[WX] sendVideo url={}, receiver={}".format(video_url, receiver))
def _send_login_success():
try:
from common.linkai_client import chat_client
if chat_client.client_id:
chat_client.send_login_success()
except Exception as e:
pass
def _send_logout():
try:
from common.linkai_client import chat_client
if chat_client.client_id:
chat_client.send_logout()
except Exception as e:
pass
def _send_qr_code(qrcode_list: list):
try:
from common.linkai_client import chat_client
if chat_client.client_id:
chat_client.send_qrcode(qrcode_list)
except Exception as e:
pass

View File

@@ -1,124 +0,0 @@
import re
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.log import logger
from common.tmp_dir import TmpDir
from lib import itchat
from lib.itchat.content import *
class WechatMessage(ChatMessage):
def __init__(self, itchat_msg, is_group=False):
super().__init__(itchat_msg)
self.msg_id = itchat_msg["MsgId"]
self.create_time = itchat_msg["CreateTime"]
self.is_group = is_group
notes_join_group = ["加入群聊", "加入了群聊", "invited", "joined"] # 可通过添加对应语言的加入群聊通知中的关键词适配更多
notes_bot_join_group = ["邀请你", "invited you", "You've joined", "你通过扫描"]
notes_exit_group = ["移出了群聊", "removed"] # 可通过添加对应语言的踢出群聊通知中的关键词适配更多
notes_patpat = ["拍了拍我", "tickled my", "tickled me"] # 可通过添加对应语言的拍一拍通知中的关键词适配更多
if itchat_msg["Type"] == TEXT:
self.ctype = ContextType.TEXT
self.content = itchat_msg["Text"]
elif itchat_msg["Type"] == VOICE:
self.ctype = ContextType.VOICE
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
self._prepare_fn = lambda: itchat_msg.download(self.content)
elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3:
self.ctype = ContextType.IMAGE
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
self._prepare_fn = lambda: itchat_msg.download(self.content)
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
if is_group:
if any(note_bot_join_group in itchat_msg["Content"] for note_bot_join_group in notes_bot_join_group): # 邀请机器人加入群聊
logger.warn("机器人加入群聊消息,不处理~")
pass
elif any(note_join_group in itchat_msg["Content"] for note_join_group in notes_join_group): # 若有任何在notes_join_group列表中的字符串出现在NOTE中
# 这里只能得到nickname actual_user_id还是机器人的id
if "加入群聊" not in itchat_msg["Content"]:
self.ctype = ContextType.JOIN_GROUP
self.content = itchat_msg["Content"]
if "invited" in itchat_msg["Content"]: # 匹配英文信息
self.actual_user_nickname = re.findall(r'invited\s+(.+?)\s+to\s+the\s+group\s+chat', itchat_msg["Content"])[0]
elif "joined" in itchat_msg["Content"]: # 匹配通过二维码加入的英文信息
self.actual_user_nickname = re.findall(r'"(.*?)" joined the group chat via the QR Code shared by', itchat_msg["Content"])[0]
elif "加入了群聊" in itchat_msg["Content"]:
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
elif "加入群聊" in itchat_msg["Content"]:
self.ctype = ContextType.JOIN_GROUP
self.content = itchat_msg["Content"]
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
elif any(note_exit_group in itchat_msg["Content"] for note_exit_group in notes_exit_group): # 若有任何在notes_exit_group列表中的字符串出现在NOTE中
self.ctype = ContextType.EXIT_GROUP
self.content = itchat_msg["Content"]
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
elif any(note_patpat in itchat_msg["Content"] for note_patpat in notes_patpat): # 若有任何在notes_patpat列表中的字符串出现在NOTE中:
self.ctype = ContextType.PATPAT
self.content = itchat_msg["Content"]
if "拍了拍我" in itchat_msg["Content"]: # 识别中文
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
elif "tickled my" in itchat_msg["Content"] or "tickled me" in itchat_msg["Content"]:
self.actual_user_nickname = re.findall(r'^(.*?)(?:tickled my|tickled me)', itchat_msg["Content"])[0]
else:
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
elif "你已添加了" in itchat_msg["Content"]: #通过好友请求
self.ctype = ContextType.ACCEPT_FRIEND
self.content = itchat_msg["Content"]
elif any(note_patpat in itchat_msg["Content"] for note_patpat in notes_patpat): # 若有任何在notes_patpat列表中的字符串出现在NOTE中:
self.ctype = ContextType.PATPAT
self.content = itchat_msg["Content"]
else:
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
elif itchat_msg["Type"] == ATTACHMENT:
self.ctype = ContextType.FILE
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
self._prepare_fn = lambda: itchat_msg.download(self.content)
elif itchat_msg["Type"] == SHARING:
self.ctype = ContextType.SHARING
self.content = itchat_msg.get("Url")
else:
raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))
self.from_user_id = itchat_msg["FromUserName"]
self.to_user_id = itchat_msg["ToUserName"]
user_id = itchat.instance.storageClass.userName
nickname = itchat.instance.storageClass.nickName
# 虽然from_user_id和to_user_id用的少但是为了保持一致性还是要填充一下
# 以下很繁琐,一句话总结:能填的都填了。
if self.from_user_id == user_id:
self.from_user_nickname = nickname
if self.to_user_id == user_id:
self.to_user_nickname = nickname
try: # 陌生人时候, User字段可能不存在
# my_msg 为True是表示是自己发送的消息
self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \
itchat_msg["ToUserName"] != itchat_msg["FromUserName"]
self.other_user_id = itchat_msg["User"]["UserName"]
self.other_user_nickname = itchat_msg["User"]["NickName"]
if self.other_user_id == self.from_user_id:
self.from_user_nickname = self.other_user_nickname
if self.other_user_id == self.to_user_id:
self.to_user_nickname = self.other_user_nickname
if itchat_msg["User"].get("Self"):
# 自身的展示名,当设置了群昵称时,该字段表示群昵称
self.self_display_name = itchat_msg["User"].get("Self").get("DisplayName")
except KeyError as e: # 处理偶尔没有对方信息的情况
logger.warn("[WX]get other_user_id failed: " + str(e))
if self.from_user_id == user_id:
self.other_user_id = self.to_user_id
else:
self.other_user_id = self.from_user_id
if self.is_group:
self.is_at = itchat_msg["IsAt"]
self.actual_user_id = itchat_msg["ActualUserName"]
if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.EXIT_GROUP]:
self.actual_user_nickname = itchat_msg["ActualNickName"]

View File

@@ -1,129 +0,0 @@
# encoding:utf-8
"""
wechaty channel
Python Wechaty - https://github.com/wechaty/python-wechaty
"""
import asyncio
import base64
import os
import time
from wechaty import Contact, Wechaty
from wechaty.user import Message
from wechaty_puppet import FileBox
from bridge.context import *
from bridge.context import Context
from bridge.reply import *
from channel.chat_channel import ChatChannel
from channel.wechat.wechaty_message import WechatyMessage
from common.log import logger
from common.singleton import singleton
from config import conf
try:
from voice.audio_convert import any_to_sil
except Exception as e:
pass
@singleton
class WechatyChannel(ChatChannel):
NOT_SUPPORT_REPLYTYPE = []
def __init__(self):
super().__init__()
def startup(self):
config = conf()
token = config.get("wechaty_puppet_service_token")
os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token
asyncio.run(self.main())
async def main(self):
loop = asyncio.get_event_loop()
# 将asyncio的loop传入处理线程
self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
self.bot = Wechaty()
self.bot.on("login", self.on_login)
self.bot.on("message", self.on_message)
await self.bot.start()
async def on_login(self, contact: Contact):
self.user_id = contact.contact_id
self.name = contact.name
logger.info("[WX] login user={}".format(contact))
# 统一的发送函数每个Channel自行实现根据reply的type字段发送不同类型的消息
def send(self, reply: Reply, context: Context):
receiver_id = context["receiver"]
loop = asyncio.get_event_loop()
if context["isgroup"]:
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
else:
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
msg = None
if reply.type == ReplyType.TEXT:
msg = reply.content
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
msg = reply.content
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.VOICE:
voiceLength = None
file_path = reply.content
sil_file = os.path.splitext(file_path)[0] + ".sil"
voiceLength = int(any_to_sil(file_path, sil_file))
if voiceLength >= 60000:
voiceLength = 60000
logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
# 发送语音
t = int(time.time())
msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
if voiceLength is not None:
msg.metadata["voiceLength"] = voiceLength
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
try:
os.remove(file_path)
if sil_file != file_path:
os.remove(sil_file)
except Exception as e:
pass
logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
t = int(time.time())
msg = FileBox.from_url(url=img_url, name=str(t) + ".png")
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
t = int(time.time())
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info("[WX] sendImage, receiver={}".format(receiver))
async def on_message(self, msg: Message):
"""
listen for message event
"""
try:
cmsg = await WechatyMessage(msg)
except NotImplementedError as e:
logger.debug("[WX] {}".format(e))
return
except Exception as e:
logger.exception("[WX] {}".format(e))
return
logger.debug("[WX] message:{}".format(cmsg))
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
isgroup = room is not None
ctype = cmsg.ctype
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
if context:
logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context))
self.produce(context)

View File

@@ -1,89 +0,0 @@
import asyncio
import re
from wechaty import MessageType
from wechaty.user import Message
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.log import logger
from common.tmp_dir import TmpDir
class aobject(object):
"""Inheriting this class allows you to define an async __init__.
So you can create objects by doing something like `await MyClass(params)`
"""
async def __new__(cls, *a, **kw):
instance = super().__new__(cls)
await instance.__init__(*a, **kw)
return instance
async def __init__(self):
pass
class WechatyMessage(ChatMessage, aobject):
async def __init__(self, wechaty_msg: Message):
super().__init__(wechaty_msg)
room = wechaty_msg.room()
self.msg_id = wechaty_msg.message_id
self.create_time = wechaty_msg.payload.timestamp
self.is_group = room is not None
if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
self.ctype = ContextType.TEXT
self.content = wechaty_msg.text()
elif wechaty_msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
self.ctype = ContextType.VOICE
voice_file = await wechaty_msg.to_file_box()
self.content = TmpDir().path() + voice_file.name # content直接存临时目录路径
def func():
loop = asyncio.get_event_loop()
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()
self._prepare_fn = func
else:
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
from_contact = wechaty_msg.talker() # 获取消息的发送者
self.from_user_id = from_contact.contact_id
self.from_user_nickname = from_contact.name
# group中的from和towechaty跟itchat含义不一样
# wecahty: from是消息实际发送者, to:所在群
# itchat: 如果是你发送群消息from和to是你自己和所在群如果是别人发群消息from和to是所在群和你自己
# 但这个差别不影响逻辑group中只使用到1.用from来判断是否是自己发的2.actual_user_id来判断实际发送用户
if self.is_group:
self.to_user_id = room.room_id
self.to_user_nickname = await room.topic()
else:
to_contact = wechaty_msg.to()
self.to_user_id = to_contact.contact_id
self.to_user_nickname = to_contact.name
if self.is_group or wechaty_msg.is_self(): # 如果是群消息other_user设置为群如果是私聊消息而且自己发的就设置成对方。
self.other_user_id = self.to_user_id
self.other_user_nickname = self.to_user_nickname
else:
self.other_user_id = self.from_user_id
self.other_user_nickname = self.from_user_nickname
if self.is_group: # wechaty群聊中实际发送用户就是from_user
self.is_at = await wechaty_msg.mention_self()
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx这里做一下兼容
name = wechaty_msg.wechaty.user_self().name
pattern = f"@{re.escape(name)}(\u2005|\u0020)"
if re.search(pattern, self.content):
logger.debug(f"wechaty message {self.msg_id} include at")
self.is_at = True
self.actual_user_id = self.from_user_id
self.actual_user_nickname = self.from_user_nickname

View File

@@ -1,6 +1,7 @@
# -*- coding=utf-8 -*-
import io
import os
import sys
import time
import requests
@@ -35,9 +36,9 @@ class WechatComAppChannel(ChatChannel):
self.agent_id = conf().get("wechatcomapp_agent_id")
self.token = conf().get("wechatcomapp_token")
self.aes_key = conf().get("wechatcomapp_aes_key")
print(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
self._http_server = None
logger.info(
"[wechatcom] init: corp_id: {}, secret: {}, agent_id: {}, token: {}, aes_key: {}".format(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
"[wechatcom] Initializing WeCom app channel, corp_id: {}, agent_id: {}".format(self.corp_id, self.agent_id)
)
self.crypto = WeChatCrypto(self.token, self.aes_key, self.corp_id)
self.client = WechatComAppClient(self.corp_id, self.secret)
@@ -47,7 +48,28 @@ class WechatComAppChannel(ChatChannel):
urls = ("/wxcomapp/?", "channel.wechatcom.wechatcomapp_channel.Query")
app = web.application(urls, globals(), autoreload=False)
port = conf().get("wechatcomapp_port", 9898)
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
logger.info("[wechatcom] ✅ WeCom app channel started successfully")
logger.info("[wechatcom] 📡 Listening on http://0.0.0.0:{}/wxcomapp/".format(port))
logger.info("[wechatcom] 🤖 Ready to receive messages")
# Build WSGI app with middleware (same as runsimple but without print)
func = web.httpserver.StaticMiddleware(app.wsgifunc())
func = web.httpserver.LogMiddleware(func)
server = web.httpserver.WSGIServer(("0.0.0.0", port), func)
self._http_server = server
try:
server.start()
except (KeyboardInterrupt, SystemExit):
server.stop()
def stop(self):
if self._http_server:
try:
self._http_server.stop()
logger.info("[wechatcom] HTTP server stopped")
except Exception as e:
logger.warning(f"[wechatcom] Error stopping HTTP server: {e}")
self._http_server = None
def send(self, reply: Reply, context: Context):
receiver = context["receiver"]
@@ -74,6 +96,10 @@ class WechatComAppChannel(ChatChannel):
response = self.client.media.upload("voice", open(path, "rb"))
logger.debug("[wechatcom] upload voice response: {}".format(response))
media_ids.append(response["media_id"])
except ImportError as e:
logger.error("[wechatcom] voice conversion failed: {}".format(e))
logger.error("[wechatcom] please install pydub: pip install pydub")
return
except WeChatClientException as e:
logger.error("[wechatcom] upload voice failed: {}".format(e))
return

View File

@@ -1,6 +1,6 @@
# 微信公众号channel
鉴于个人微信号在服务器上通过itchat登录有封号风险这里新增了微信公众号channel提供无风险的服务。
微信公众号channel提供稳定的服务。
目前支持订阅号和服务号两种类型的公众号,它们都支持文本交互,语音和图片输入。其中个人主体的微信订阅号由于无法通过微信认证,存在回复时间限制,每天的图片和声音回复次数也有限制。
## 使用方法(订阅号,服务号类似)

View File

@@ -21,7 +21,11 @@ from common.log import logger
from common.singleton import singleton
from common.utils import split_string_by_utf8_length, remove_markdown_symbol
from config import conf
from voice.audio_convert import any_to_mp3, split_audio
try:
from voice.audio_convert import any_to_mp3, split_audio
except ImportError as e:
logger.debug("import voice.audio_convert failed, voice features will not be supported: {}".format(e))
# If using SSL, uncomment the following lines, and modify the certificate path.
# from cheroot.server import HTTPServer
@@ -37,6 +41,7 @@ class WechatMPChannel(ChatChannel):
super().__init__()
self.passive_reply = passive_reply
self.NOT_SUPPORT_REPLYTYPE = []
self._http_server = None
appid = conf().get("wechatmp_app_id")
secret = conf().get("wechatmp_app_secret")
token = conf().get("wechatmp_token")
@@ -65,7 +70,23 @@ class WechatMPChannel(ChatChannel):
urls = ("/wx", "channel.wechatmp.active_reply.Query")
app = web.application(urls, globals(), autoreload=False)
port = conf().get("wechatmp_port", 8080)
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
func = web.httpserver.StaticMiddleware(app.wsgifunc())
func = web.httpserver.LogMiddleware(func)
server = web.httpserver.WSGIServer(("0.0.0.0", port), func)
self._http_server = server
try:
server.start()
except (KeyboardInterrupt, SystemExit):
server.stop()
def stop(self):
if self._http_server:
try:
self._http_server.stop()
logger.info("[wechatmp] HTTP server stopped")
except Exception as e:
logger.warning(f"[wechatmp] Error stopping HTTP server: {e}")
self._http_server = None
def start_loop(self, loop):
asyncio.set_event_loop(loop)
@@ -85,26 +106,31 @@ class WechatMPChannel(ChatChannel):
logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
self.cache_dict[receiver].append(("text", reply_text))
elif reply.type == ReplyType.VOICE:
voice_file_path = reply.content
duration, files = split_audio(voice_file_path, 60 * 1000)
if len(files) > 1:
logger.info("[wechatmp] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
try:
voice_file_path = reply.content
duration, files = split_audio(voice_file_path, 60 * 1000)
if len(files) > 1:
logger.info("[wechatmp] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
for path in files:
# support: <2M, <60s, mp3/wma/wav/amr
try:
with open(path, "rb") as f:
response = self.client.material.add("voice", f)
logger.debug("[wechatmp] upload voice response: {}".format(response))
f_size = os.fstat(f.fileno()).st_size
time.sleep(1.0 + 2 * f_size / 1024 / 1024)
# todo check media_id
except WeChatClientException as e:
logger.error("[wechatmp] upload voice failed: {}".format(e))
return
media_id = response["media_id"]
logger.info("[wechatmp] voice uploaded, receiver {}, media_id {}".format(receiver, media_id))
self.cache_dict[receiver].append(("voice", media_id))
for path in files:
# support: <2M, <60s, mp3/wma/wav/amr
try:
with open(path, "rb") as f:
response = self.client.material.add("voice", f)
logger.debug("[wechatmp] upload voice response: {}".format(response))
f_size = os.fstat(f.fileno()).st_size
time.sleep(1.0 + 2 * f_size / 1024 / 1024)
# todo check media_id
except WeChatClientException as e:
logger.error("[wechatmp] upload voice failed: {}".format(e))
return
media_id = response["media_id"]
logger.info("[wechatmp] voice uploaded, receiver {}, media_id {}".format(receiver, media_id))
self.cache_dict[receiver].append(("voice", media_id))
except ImportError as e:
logger.error("[wechatmp] voice conversion failed: {}".format(e))
logger.error("[wechatmp] please install pydub: pip install pydub")
return
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
@@ -213,6 +239,10 @@ class WechatMPChannel(ChatChannel):
logger.debug("[wechatcom] upload voice response: {}".format(response))
media_ids.append(response["media_id"])
os.remove(path)
except ImportError as e:
logger.error("[wechatmp] voice conversion failed: {}".format(e))
logger.error("[wechatmp] please install pydub: pip install pydub")
return
except WeChatClientException as e:
logger.error("[wechatmp] upload voice failed: {}".format(e))
return

View File

View File

@@ -0,0 +1,768 @@
"""
WeCom (企业微信) AI Bot channel via WebSocket long connection.
Supports:
- Single chat and group chat (text / image / file input & output)
- Scheduled task push via aibot_send_msg
- Heartbeat keep-alive and auto-reconnect
"""
import base64
import hashlib
import json
import math
import os
import threading
import time
import uuid
import requests
import websocket
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel, check_prefix
from channel.wecom_bot.wecom_bot_message import WecomBotMessage
from common.expired_dict import ExpiredDict
from common.log import logger
from common.singleton import singleton
from common.ws_client_compat import websocket_app_run_forever
from config import conf
WECOM_WS_URL = "wss://openws.work.weixin.qq.com"
HEARTBEAT_INTERVAL = 30
MEDIA_CHUNK_SIZE = 512 * 1024 # 512KB per chunk (before base64 encoding)
@singleton
class WecomBotChannel(ChatChannel):
def __init__(self):
super().__init__()
self.bot_id = ""
self.bot_secret = ""
self.received_msgs = ExpiredDict(60 * 60 * 7.1)
self._ws = None
self._ws_thread = None
self._heartbeat_thread = None
self._connected = False
self._stop_event = threading.Event()
self._pending_responses = {} # req_id -> (threading.Event, result_holder)
self._pending_lock = threading.Lock()
self._stream_states = {} # req_id -> {"stream_id": str, "content": str}
conf()["group_name_white_list"] = ["ALL_GROUP"]
conf()["single_chat_prefix"] = [""]
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def startup(self):
self.bot_id = conf().get("wecom_bot_id", "")
self.bot_secret = conf().get("wecom_bot_secret", "")
if not self.bot_id or not self.bot_secret:
err = "[WecomBot] wecom_bot_id and wecom_bot_secret are required"
logger.error(err)
self.report_startup_error(err)
return
self._stop_event.clear()
self._start_ws()
def stop(self):
logger.info("[WecomBot] stop() called")
self._stop_event.set()
if self._ws:
try:
self._ws.close()
except Exception:
pass
self._ws = None
self._connected = False
# ------------------------------------------------------------------
# WebSocket connection
# ------------------------------------------------------------------
def _start_ws(self):
def _on_open(ws):
logger.info("[WecomBot] WebSocket connected, sending subscribe...")
self._send_subscribe()
def _on_message(ws, raw):
try:
data = json.loads(raw)
self._handle_ws_message(data)
except Exception as e:
logger.error(f"[WecomBot] Failed to handle ws message: {e}", exc_info=True)
def _on_error(ws, error):
logger.error(f"[WecomBot] WebSocket error: {error}")
def _on_close(ws, close_status_code, close_msg):
logger.warning(f"[WecomBot] WebSocket closed: status={close_status_code}, msg={close_msg}")
self._connected = False
if not self._stop_event.is_set():
logger.info("[WecomBot] Will reconnect in 5s...")
time.sleep(5)
if not self._stop_event.is_set():
self._start_ws()
self._ws = websocket.WebSocketApp(
WECOM_WS_URL,
on_open=_on_open,
on_message=_on_message,
on_error=_on_error,
on_close=_on_close,
)
def run_forever():
try:
websocket_app_run_forever(self._ws, ping_interval=0, reconnect=0)
except (SystemExit, KeyboardInterrupt):
logger.info("[WecomBot] WebSocket thread interrupted")
except Exception as e:
logger.error(f"[WecomBot] WebSocket run_forever error: {e}")
self._ws_thread = threading.Thread(target=run_forever, daemon=True)
self._ws_thread.start()
self._ws_thread.join()
def _ws_send(self, data: dict):
if self._ws:
self._ws.send(json.dumps(data, ensure_ascii=False))
def _gen_req_id(self) -> str:
return uuid.uuid4().hex[:16]
# ------------------------------------------------------------------
# Subscribe & heartbeat
# ------------------------------------------------------------------
def _send_subscribe(self):
self._ws_send({
"cmd": "aibot_subscribe",
"headers": {"req_id": self._gen_req_id()},
"body": {
"bot_id": self.bot_id,
"secret": self.bot_secret,
},
})
def _start_heartbeat(self):
if self._heartbeat_thread and self._heartbeat_thread.is_alive():
return
def heartbeat_loop():
while not self._stop_event.is_set() and self._connected:
try:
self._ws_send({
"cmd": "ping",
"headers": {"req_id": self._gen_req_id()},
})
except Exception as e:
logger.warning(f"[WecomBot] Heartbeat send failed: {e}")
break
self._stop_event.wait(HEARTBEAT_INTERVAL)
self._heartbeat_thread = threading.Thread(target=heartbeat_loop, daemon=True)
self._heartbeat_thread.start()
# ------------------------------------------------------------------
# Incoming message dispatch
# ------------------------------------------------------------------
def _send_and_wait(self, data: dict, timeout: float = 15) -> dict:
"""Send a ws message and wait for the matching response by req_id."""
req_id = data.get("headers", {}).get("req_id", "")
event = threading.Event()
holder = {"data": None}
with self._pending_lock:
self._pending_responses[req_id] = (event, holder)
self._ws_send(data)
event.wait(timeout=timeout)
with self._pending_lock:
self._pending_responses.pop(req_id, None)
return holder["data"] or {}
def _handle_ws_message(self, data: dict):
cmd = data.get("cmd", "")
errcode = data.get("errcode")
req_id = data.get("headers", {}).get("req_id", "")
# Check if this is a response to a pending request
if req_id:
with self._pending_lock:
pending = self._pending_responses.get(req_id)
if pending:
event, holder = pending
holder["data"] = data
event.set()
return
# Subscribe response (only handle once before connected)
if errcode is not None and cmd == "":
if not self._connected:
if errcode == 0:
logger.info("[WecomBot] ✅ Subscribe success")
self._connected = True
self._start_heartbeat()
self.report_startup_success()
else:
errmsg = data.get("errmsg", "unknown error")
logger.error(f"[WecomBot] Subscribe failed: errcode={errcode}, errmsg={errmsg}")
self.report_startup_error(errmsg)
return
if cmd == "aibot_msg_callback":
self._handle_msg_callback(data)
elif cmd == "aibot_event_callback":
self._handle_event_callback(data)
elif cmd == "":
if errcode and errcode != 0:
logger.warning(f"[WecomBot] Response error: {data}")
# ------------------------------------------------------------------
# Message callback
# ------------------------------------------------------------------
def _handle_msg_callback(self, data: dict):
body = data.get("body", {})
req_id = data.get("headers", {}).get("req_id", "")
msg_id = body.get("msgid", "")
if self.received_msgs.get(msg_id):
logger.debug(f"[WecomBot] Duplicate msg filtered: {msg_id}")
return
self.received_msgs[msg_id] = True
chattype = body.get("chattype", "single")
is_group = chattype == "group"
try:
wecom_msg = WecomBotMessage(body, is_group=is_group)
except NotImplementedError as e:
logger.warning(f"[WecomBot] {e}")
return
except Exception as e:
logger.error(f"[WecomBot] Failed to parse message: {e}", exc_info=True)
return
wecom_msg.req_id = req_id
# File cache logic (same pattern as feishu)
from channel.file_cache import get_file_cache
file_cache = get_file_cache()
if is_group:
if conf().get("group_shared_session", True):
session_id = body.get("chatid", "")
else:
session_id = wecom_msg.from_user_id + "_" + body.get("chatid", "")
else:
session_id = wecom_msg.from_user_id
if wecom_msg.ctype == ContextType.IMAGE:
if hasattr(wecom_msg, "image_path") and wecom_msg.image_path:
file_cache.add(session_id, wecom_msg.image_path, file_type="image")
logger.info(f"[WecomBot] Image cached for session {session_id}")
return
if wecom_msg.ctype == ContextType.FILE:
wecom_msg.prepare()
file_cache.add(session_id, wecom_msg.content, file_type="file")
logger.info(f"[WecomBot] File cached for session {session_id}: {wecom_msg.content}")
return
if wecom_msg.ctype == ContextType.TEXT:
cached_files = file_cache.get(session_id)
if cached_files:
file_refs = []
for fi in cached_files:
ftype = fi["type"]
fpath = fi["path"]
if ftype == "image":
file_refs.append(f"[图片: {fpath}]")
elif ftype == "video":
file_refs.append(f"[视频: {fpath}]")
else:
file_refs.append(f"[文件: {fpath}]")
wecom_msg.content = wecom_msg.content + "\n" + "\n".join(file_refs)
logger.info(f"[WecomBot] Attached {len(cached_files)} cached file(s)")
file_cache.clear(session_id)
context = self._compose_context(
wecom_msg.ctype,
wecom_msg.content,
isgroup=is_group,
msg=wecom_msg,
no_need_at=True,
)
if context:
if req_id:
context["on_event"] = self._make_stream_callback(req_id)
self.produce(context)
# ------------------------------------------------------------------
# Event callback
# ------------------------------------------------------------------
def _handle_event_callback(self, data: dict):
body = data.get("body", {})
event = body.get("event", {})
event_type = event.get("eventtype", "")
if event_type == "enter_chat":
logger.info(f"[WecomBot] User entered chat: {body.get('from', {}).get('userid')}")
elif event_type == "disconnected_event":
logger.warning("[WecomBot] Received disconnected_event, another connection took over")
else:
logger.debug(f"[WecomBot] Event: {event_type}")
# ------------------------------------------------------------------
# Stream callback (for agent on_event)
# ------------------------------------------------------------------
def _make_stream_callback(self, req_id: str):
"""Build an on_event callback that pushes agent stream deltas to wecom via stream message.
All intermediate segments (thinking before tool calls) and the final answer
are accumulated into a single stream message, separated by '---'.
"""
stream_id = uuid.uuid4().hex[:16]
self._stream_states[req_id] = {
"stream_id": stream_id,
"committed": "", # finalized content from previous segments
"current": "", # current segment being streamed
}
def _push_stream(state: dict):
"""Push current stream content to wecom."""
self._ws_send({
"cmd": "aibot_respond_msg",
"headers": {"req_id": req_id},
"body": {
"msgtype": "stream",
"stream": {
"id": state["stream_id"],
"finish": False,
"content": state["committed"] + state["current"],
},
},
})
def on_event(event: dict):
event_type = event.get("type")
data = event.get("data", {})
state = self._stream_states.get(req_id)
if not state:
return
if event_type == "turn_start":
state["current"] = ""
elif event_type == "message_update":
delta = data.get("delta", "")
if delta:
state["current"] += delta
_push_stream(state)
elif event_type == "message_end":
tool_calls = data.get("tool_calls", [])
if tool_calls:
if state["current"].strip():
state["committed"] += state["current"].strip() + "\n\n---\n\n"
state["current"] = ""
else:
state["committed"] += state["current"]
state["current"] = ""
return on_event
# ------------------------------------------------------------------
# _compose_context (same pattern as feishu)
# ------------------------------------------------------------------
def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content)
context.kwargs = kwargs
if "channel_type" not in context:
context["channel_type"] = self.channel_type
if "origin_ctype" not in context:
context["origin_ctype"] = ctype
cmsg = context["msg"]
if cmsg.is_group:
if conf().get("group_shared_session", True):
context["session_id"] = cmsg.other_user_id
else:
context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}"
else:
context["session_id"] = cmsg.from_user_id
context["receiver"] = cmsg.other_user_id
if ctype == ContextType.TEXT:
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
if img_match_prefix:
content = content.replace(img_match_prefix, "", 1)
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT
context.content = content.strip()
return context
# ------------------------------------------------------------------
# Send reply
# ------------------------------------------------------------------
def send(self, reply: Reply, context: Context):
msg = context.get("msg")
is_group = context.get("isgroup", False)
receiver = context.get("receiver", "")
# Determine req_id for responding or use send_msg for scheduled push
req_id = getattr(msg, "req_id", None) if msg else None
if reply.type == ReplyType.TEXT:
self._send_text(reply.content, receiver, is_group, req_id)
elif reply.type in (ReplyType.IMAGE_URL, ReplyType.IMAGE):
self._send_image(reply.content, receiver, is_group, req_id)
elif reply.type == ReplyType.FILE:
if hasattr(reply, "text_content") and reply.text_content:
self._send_text(reply.text_content, receiver, is_group, req_id)
time.sleep(0.3)
self._send_file(reply.content, receiver, is_group, req_id)
elif reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL:
self._send_file(reply.content, receiver, is_group, req_id, media_type="video")
else:
logger.warning(f"[WecomBot] Unsupported reply type: {reply.type}, falling back to text")
self._send_text(str(reply.content), receiver, is_group, req_id)
# ------------------------------------------------------------------
# Respond message (via websocket)
# ------------------------------------------------------------------
def _send_text(self, content: str, receiver: str, is_group: bool, req_id: str = None):
"""Send text/markdown reply. Reuses stream state if available (streaming mode)."""
if req_id:
state = self._stream_states.pop(req_id, None)
if state:
final_content = state["committed"]
stream_id = state["stream_id"]
else:
final_content = content
stream_id = uuid.uuid4().hex[:16]
self._ws_send({
"cmd": "aibot_respond_msg",
"headers": {"req_id": req_id},
"body": {
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": final_content,
},
},
})
else:
self._active_send_markdown(content, receiver, is_group)
def _send_image(self, img_path_or_url: str, receiver: str, is_group: bool, req_id: str = None):
"""Send image reply. Converts to JPG/PNG and compresses if >2MB."""
local_path = img_path_or_url
if local_path.startswith("file://"):
local_path = local_path[7:]
if local_path.startswith(("http://", "https://")):
try:
resp = requests.get(local_path, timeout=30)
resp.raise_for_status()
ct = resp.headers.get("Content-Type", "")
if "jpeg" in ct or "jpg" in ct:
ext = ".jpg"
elif "webp" in ct:
ext = ".webp"
elif "gif" in ct:
ext = ".gif"
else:
ext = ".png"
tmp_path = f"/tmp/wecom_img_{uuid.uuid4().hex[:8]}{ext}"
with open(tmp_path, "wb") as f:
f.write(resp.content)
logger.info(f"[WecomBot] Image downloaded: size={len(resp.content)}, "
f"content-type={ct}, path={tmp_path}")
local_path = tmp_path
except Exception as e:
logger.error(f"[WecomBot] Failed to download image for sending: {e}")
self._send_text("[Image send failed]", receiver, is_group, req_id)
return
if not os.path.exists(local_path):
logger.error(f"[WecomBot] Image file not found: {local_path}")
return
max_image_size = 2 * 1024 * 1024 # 2MB limit for image upload
local_path = self._ensure_image_format(local_path)
if not local_path:
self._send_text("[Image format conversion failed]", receiver, is_group, req_id)
return
if os.path.getsize(local_path) > max_image_size:
local_path = self._compress_image(local_path, max_image_size)
if not local_path:
self._send_text("[Image too large]", receiver, is_group, req_id)
return
file_size = os.path.getsize(local_path)
logger.info(f"[WecomBot] Uploading image: path={local_path}, size={file_size} bytes")
media_id = self._upload_media(local_path, "image")
if not media_id:
logger.error("[WecomBot] Failed to upload image")
self._send_text("[Image upload failed]", receiver, is_group, req_id)
return
if req_id:
self._ws_send({
"cmd": "aibot_respond_msg",
"headers": {"req_id": req_id},
"body": {
"msgtype": "image",
"image": {"media_id": media_id},
},
})
else:
self._ws_send({
"cmd": "aibot_send_msg",
"headers": {"req_id": self._gen_req_id()},
"body": {
"chatid": receiver,
"chat_type": 2 if is_group else 1,
"msgtype": "image",
"image": {"media_id": media_id},
},
})
@staticmethod
def _ensure_image_format(file_path: str) -> str:
"""Ensure image is JPG or PNG (the only formats wecom supports). Convert if needed."""
try:
from PIL import Image
img = Image.open(file_path)
fmt = (img.format or "").upper()
if fmt in ("JPEG", "PNG"):
# Already a supported format, but make sure the filename extension matches
ext = os.path.splitext(file_path)[1].lower()
if fmt == "JPEG" and ext in (".jpg", ".jpeg"):
return file_path
if fmt == "PNG" and ext == ".png":
return file_path
# Extension doesn't match — rename/copy with correct extension
correct_ext = ".jpg" if fmt == "JPEG" else ".png"
out_path = f"/tmp/wecom_fmt_{uuid.uuid4().hex[:8]}{correct_ext}"
img.save(out_path, fmt)
logger.info(f"[WecomBot] Image renamed: {file_path} -> {out_path} ({fmt})")
return out_path
# Unsupported format (WebP, GIF, BMP, etc.) — convert to PNG
if img.mode == "RGBA":
out_path = f"/tmp/wecom_fmt_{uuid.uuid4().hex[:8]}.png"
img.save(out_path, "PNG")
else:
out_path = f"/tmp/wecom_fmt_{uuid.uuid4().hex[:8]}.jpg"
img.convert("RGB").save(out_path, "JPEG", quality=90)
logger.info(f"[WecomBot] Image converted from {fmt} -> {out_path}")
return out_path
except Exception as e:
logger.error(f"[WecomBot] Image format check failed: {e}")
return file_path
@staticmethod
def _compress_image(file_path: str, max_bytes: int) -> str:
"""Compress image to fit within max_bytes. Returns new path or empty string."""
try:
from PIL import Image
img = Image.open(file_path)
if img.mode == "RGBA":
img = img.convert("RGB")
out_path = f"/tmp/wecom_compressed_{uuid.uuid4().hex[:8]}.jpg"
quality = 85
while quality >= 30:
img.save(out_path, "JPEG", quality=quality, optimize=True)
if os.path.getsize(out_path) <= max_bytes:
logger.info(f"[WecomBot] Image compressed: quality={quality}, "
f"size={os.path.getsize(out_path)} bytes")
return out_path
quality -= 10
# Still too large — resize
ratio = (max_bytes / os.path.getsize(out_path)) ** 0.5
new_size = (int(img.width * ratio), int(img.height * ratio))
img = img.resize(new_size, Image.LANCZOS)
img.save(out_path, "JPEG", quality=70, optimize=True)
if os.path.getsize(out_path) <= max_bytes:
logger.info(f"[WecomBot] Image compressed with resize: {new_size}, "
f"size={os.path.getsize(out_path)} bytes")
return out_path
logger.error(f"[WecomBot] Cannot compress image below {max_bytes} bytes")
return ""
except Exception as e:
logger.error(f"[WecomBot] Image compression failed: {e}")
return ""
def _send_file(self, file_path: str, receiver: str, is_group: bool,
req_id: str = None, media_type: str = "file"):
"""Send file/video reply by uploading media first."""
local_path = file_path
if local_path.startswith("file://"):
local_path = local_path[7:]
if local_path.startswith(("http://", "https://")):
try:
resp = requests.get(local_path, timeout=60)
resp.raise_for_status()
ext = os.path.splitext(local_path)[1] or ".bin"
tmp_path = f"/tmp/wecom_file_{uuid.uuid4().hex[:8]}{ext}"
with open(tmp_path, "wb") as f:
f.write(resp.content)
local_path = tmp_path
except Exception as e:
logger.error(f"[WecomBot] Failed to download file for sending: {e}")
return
if not os.path.exists(local_path):
logger.error(f"[WecomBot] File not found: {local_path}")
return
media_id = self._upload_media(local_path, media_type)
if not media_id:
logger.error(f"[WecomBot] Failed to upload {media_type}")
return
if req_id:
self._ws_send({
"cmd": "aibot_respond_msg",
"headers": {"req_id": req_id},
"body": {
"msgtype": media_type,
media_type: {"media_id": media_id},
},
})
else:
self._ws_send({
"cmd": "aibot_send_msg",
"headers": {"req_id": self._gen_req_id()},
"body": {
"chatid": receiver,
"chat_type": 2 if is_group else 1,
"msgtype": media_type,
media_type: {"media_id": media_id},
},
})
def _active_send_markdown(self, content: str, receiver: str, is_group: bool):
"""Proactively send markdown message (for scheduled tasks, no req_id)."""
self._ws_send({
"cmd": "aibot_send_msg",
"headers": {"req_id": self._gen_req_id()},
"body": {
"chatid": receiver,
"chat_type": 2 if is_group else 1,
"msgtype": "markdown",
"markdown": {"content": content},
},
})
# ------------------------------------------------------------------
# Media upload (chunked)
# ------------------------------------------------------------------
def _upload_media(self, file_path: str, media_type: str = "file") -> str:
"""
Upload a local file to wecom bot via chunked upload protocol.
Returns media_id on success, empty string on failure.
"""
if not os.path.exists(file_path):
logger.error(f"[WecomBot] Upload file not found: {file_path}")
return ""
file_size = os.path.getsize(file_path)
if file_size < 5:
logger.error(f"[WecomBot] File too small: {file_size} bytes")
return ""
filename = os.path.basename(file_path)
total_chunks = math.ceil(file_size / MEDIA_CHUNK_SIZE)
if total_chunks > 100:
logger.error(f"[WecomBot] Too many chunks: {total_chunks} > 100")
return ""
file_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for block in iter(lambda: f.read(8192), b""):
file_md5.update(block)
md5_hex = file_md5.hexdigest()
# 1. Init upload
init_resp = self._send_and_wait({
"cmd": "aibot_upload_media_init",
"headers": {"req_id": self._gen_req_id()},
"body": {
"type": media_type,
"filename": filename,
"total_size": file_size,
"total_chunks": total_chunks,
"md5": md5_hex,
},
}, timeout=15)
if init_resp.get("errcode") != 0:
logger.error(f"[WecomBot] Upload init failed: {init_resp}")
return ""
upload_id = init_resp.get("body", {}).get("upload_id")
if not upload_id:
logger.error("[WecomBot] Failed to get upload_id")
return ""
# 2. Upload chunks
with open(file_path, "rb") as f:
for idx in range(total_chunks):
chunk = f.read(MEDIA_CHUNK_SIZE)
b64_data = base64.b64encode(chunk).decode("utf-8")
chunk_resp = self._send_and_wait({
"cmd": "aibot_upload_media_chunk",
"headers": {"req_id": self._gen_req_id()},
"body": {
"upload_id": upload_id,
"chunk_index": idx,
"base64_data": b64_data,
},
}, timeout=30)
if chunk_resp.get("errcode") != 0:
logger.error(f"[WecomBot] Chunk {idx} upload failed: {chunk_resp}")
return ""
# 3. Finish upload
finish_resp = self._send_and_wait({
"cmd": "aibot_upload_media_finish",
"headers": {"req_id": self._gen_req_id()},
"body": {"upload_id": upload_id},
}, timeout=30)
if finish_resp.get("errcode") != 0:
logger.error(f"[WecomBot] Upload finish failed: {finish_resp}")
return ""
media_id = finish_resp.get("body", {}).get("media_id", "")
if media_id:
logger.info(f"[WecomBot] Media uploaded: media_id={media_id}")
else:
logger.error("[WecomBot] Failed to get media_id from finish response")
return media_id

View File

@@ -0,0 +1,216 @@
import os
import re
import base64
import requests
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.log import logger
from common.utils import expand_path
from config import conf
from Crypto.Cipher import AES
MAGIC_SIGNATURES = [
(b"%PDF", ".pdf"),
(b"\x89PNG\r\n\x1a\n", ".png"),
(b"\xff\xd8\xff", ".jpg"),
(b"GIF87a", ".gif"),
(b"GIF89a", ".gif"),
(b"RIFF", ".webp"), # RIFF....WEBP, further checked below
(b"PK\x03\x04", ".zip"), # zip / docx / xlsx / pptx
(b"\x1f\x8b", ".gz"),
(b"Rar!\x1a\x07", ".rar"),
(b"7z\xbc\xaf\x27\x1c", ".7z"),
(b"\x00\x00\x00", ".mp4"), # ftyp box, further checked below
(b"#!AMR", ".amr"),
]
OFFICE_ZIP_MARKERS = {
b"word/": ".docx",
b"xl/": ".xlsx",
b"ppt/": ".pptx",
}
def _guess_ext_from_bytes(data: bytes) -> str:
"""Guess file extension from file content magic bytes."""
if not data or len(data) < 8:
return ""
for sig, ext in MAGIC_SIGNATURES:
if data[:len(sig)] == sig:
if ext == ".webp" and data[8:12] != b"WEBP":
continue
if ext == ".mp4":
if b"ftyp" not in data[4:12]:
continue
if ext == ".zip":
for marker, office_ext in OFFICE_ZIP_MARKERS.items():
if marker in data[:2000]:
return office_ext
return ".zip"
return ext
return ""
def _decrypt_media(url: str, aeskey: str) -> bytes:
"""
Download and decrypt AES-256-CBC encrypted media from wecom bot.
Returns decrypted bytes.
"""
resp = requests.get(url, timeout=30)
resp.raise_for_status()
encrypted = resp.content
key = base64.b64decode(aeskey + "=" * (-len(aeskey) % 4))
if len(key) != 32:
raise ValueError(f"Invalid AES key length: {len(key)}, expected 32")
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
decrypted = cipher.decrypt(encrypted)
pad_len = decrypted[-1]
if pad_len > 32:
raise ValueError(f"Invalid PKCS7 padding length: {pad_len}")
return decrypted[:-pad_len]
def _get_tmp_dir() -> str:
"""Return the workspace tmp directory (absolute path), creating it if needed."""
ws_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(ws_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
return tmp_dir
class WecomBotMessage(ChatMessage):
"""Message wrapper for wecom bot (websocket long-connection mode)."""
def __init__(self, msg_body: dict, is_group: bool = False):
super().__init__(msg_body)
self.msg_id = msg_body.get("msgid")
self.create_time = msg_body.get("create_time")
self.is_group = is_group
msg_type = msg_body.get("msgtype")
from_userid = msg_body.get("from", {}).get("userid", "")
chat_id = msg_body.get("chatid", "")
bot_id = msg_body.get("aibotid", "")
if msg_type == "text":
self.ctype = ContextType.TEXT
content = msg_body.get("text", {}).get("content", "")
if is_group:
content = re.sub(r"@\S+\s*", "", content).strip()
self.content = content
elif msg_type == "voice":
self.ctype = ContextType.TEXT
self.content = msg_body.get("voice", {}).get("content", "")
elif msg_type == "image":
self.ctype = ContextType.IMAGE
image_info = msg_body.get("image", {})
image_url = image_info.get("url", "")
aeskey = image_info.get("aeskey", "")
tmp_dir = _get_tmp_dir()
image_path = os.path.join(tmp_dir, f"wecom_{self.msg_id}.png")
try:
data = _decrypt_media(image_url, aeskey)
with open(image_path, "wb") as f:
f.write(data)
self.content = image_path
self.image_path = image_path
logger.info(f"[WecomBot] Image downloaded: {image_path}")
except Exception as e:
logger.error(f"[WecomBot] Failed to download image: {e}")
self.content = "[Image download failed]"
self.image_path = None
elif msg_type == "mixed":
self.ctype = ContextType.TEXT
text_parts = []
image_paths = []
mixed_items = msg_body.get("mixed", {}).get("msg_item", [])
tmp_dir = _get_tmp_dir()
for idx, item in enumerate(mixed_items):
item_type = item.get("msgtype")
if item_type == "text":
txt = item.get("text", {}).get("content", "")
if is_group:
txt = re.sub(r"@\S+\s*", "", txt).strip()
if txt:
text_parts.append(txt)
elif item_type == "image":
img_info = item.get("image", {})
img_url = img_info.get("url", "")
img_aeskey = img_info.get("aeskey", "")
img_path = os.path.join(tmp_dir, f"wecom_{self.msg_id}_{idx}.png")
try:
img_data = _decrypt_media(img_url, img_aeskey)
with open(img_path, "wb") as f:
f.write(img_data)
image_paths.append(img_path)
except Exception as e:
logger.error(f"[WecomBot] Failed to download mixed image: {e}")
content_parts = text_parts[:]
for p in image_paths:
content_parts.append(f"[图片: {p}]")
self.content = "\n".join(content_parts) if content_parts else "[Mixed message]"
elif msg_type == "file":
self.ctype = ContextType.FILE
file_info = msg_body.get("file", {})
file_url = file_info.get("url", "")
aeskey = file_info.get("aeskey", "")
tmp_dir = _get_tmp_dir()
base_path = os.path.join(tmp_dir, f"wecom_{self.msg_id}")
self.content = base_path
def _download_file():
try:
data = _decrypt_media(file_url, aeskey)
ext = _guess_ext_from_bytes(data)
final_path = base_path + ext
with open(final_path, "wb") as f:
f.write(data)
self.content = final_path
logger.info(f"[WecomBot] File downloaded: {final_path}")
except Exception as e:
logger.error(f"[WecomBot] Failed to download file: {e}")
self._prepare_fn = _download_file
elif msg_type == "video":
self.ctype = ContextType.FILE
video_info = msg_body.get("video", {})
video_url = video_info.get("url", "")
aeskey = video_info.get("aeskey", "")
tmp_dir = _get_tmp_dir()
self.content = os.path.join(tmp_dir, f"wecom_{self.msg_id}.mp4")
def _download_video():
try:
data = _decrypt_media(video_url, aeskey)
with open(self.content, "wb") as f:
f.write(data)
logger.info(f"[WecomBot] Video downloaded: {self.content}")
except Exception as e:
logger.error(f"[WecomBot] Failed to download video: {e}")
self._prepare_fn = _download_video
else:
raise NotImplementedError(f"Unsupported message type: {msg_type}")
self.from_user_id = from_userid
self.to_user_id = bot_id
if is_group:
self.other_user_id = chat_id
self.actual_user_id = from_userid
self.actual_user_nickname = from_userid
else:
self.other_user_id = from_userid
self.actual_user_id = from_userid

View File

View File

@@ -0,0 +1,395 @@
"""
Weixin HTTP JSON API client.
Implements the ilink bot protocol:
- getUpdates (long-poll)
- sendMessage
- getUploadUrl
- getConfig
- sendTyping
- QR login (get_bot_qrcode / get_qrcode_status)
CDN media upload with AES-128-ECB encryption.
"""
import base64
import hashlib
import os
import random
import struct
import time
import uuid
import requests
from common.log import logger
DEFAULT_BASE_URL = "https://ilinkai.weixin.qq.com"
CDN_BASE_URL = "https://novac2c.cdn.weixin.qq.com/c2c"
DEFAULT_LONG_POLL_TIMEOUT = 35
DEFAULT_API_TIMEOUT = 15
QR_POLL_TIMEOUT = 35
BOT_TYPE = "3"
def _random_wechat_uin() -> str:
val = random.randint(0, 0xFFFFFFFF)
return base64.b64encode(str(val).encode("utf-8")).decode("utf-8")
def _build_headers(token: str = "") -> dict:
headers = {
"Content-Type": "application/json",
"AuthorizationType": "ilink_bot_token",
"X-WECHAT-UIN": _random_wechat_uin(),
}
if token:
headers["Authorization"] = f"Bearer {token}"
return headers
def _ensure_trailing_slash(url: str) -> str:
return url if url.endswith("/") else url + "/"
class WeixinApi:
"""Stateless HTTP client for the Weixin ilink bot API."""
def __init__(self, base_url: str = DEFAULT_BASE_URL, token: str = "",
cdn_base_url: str = CDN_BASE_URL):
self.base_url = base_url
self.token = token
self.cdn_base_url = cdn_base_url
def _post(self, endpoint: str, body: dict, timeout: int = DEFAULT_API_TIMEOUT) -> dict:
url = _ensure_trailing_slash(self.base_url) + endpoint
headers = _build_headers(self.token)
try:
resp = requests.post(url, json=body, headers=headers, timeout=timeout)
resp.raise_for_status()
return resp.json()
except requests.exceptions.Timeout:
logger.debug(f"[Weixin] API timeout: {endpoint}")
return {"ret": 0, "msgs": []}
except Exception as e:
logger.error(f"[Weixin] API error {endpoint}: {e}")
raise
# ── getUpdates (long-poll) ─────────────────────────────────────────
def get_updates(self, get_updates_buf: str = "", timeout: int = DEFAULT_LONG_POLL_TIMEOUT) -> dict:
return self._post("ilink/bot/getupdates", {
"get_updates_buf": get_updates_buf,
}, timeout=timeout + 5)
# ── sendMessage ────────────────────────────────────────────────────
def send_text(self, to: str, text: str, context_token: str) -> dict:
return self._post("ilink/bot/sendmessage", {
"msg": {
"from_user_id": "",
"to_user_id": to,
"client_id": uuid.uuid4().hex[:16],
"message_type": 2, # BOT
"message_state": 2, # FINISH
"item_list": [{"type": 1, "text_item": {"text": text}}],
"context_token": context_token,
}
})
def send_image_item(self, to: str, context_token: str,
encrypt_query_param: str, aes_key_b64: str,
ciphertext_size: int, text: str = "") -> dict:
items = []
if text:
items.append({"type": 1, "text_item": {"text": text}})
items.append({
"type": 2,
"image_item": {
"media": {
"encrypt_query_param": encrypt_query_param,
"aes_key": aes_key_b64,
"encrypt_type": 1,
},
"mid_size": ciphertext_size,
}
})
return self._send_items(to, context_token, items)
def send_file_item(self, to: str, context_token: str,
encrypt_query_param: str, aes_key_b64: str,
file_name: str, file_size: int, text: str = "") -> dict:
items = []
if text:
items.append({"type": 1, "text_item": {"text": text}})
items.append({
"type": 4,
"file_item": {
"media": {
"encrypt_query_param": encrypt_query_param,
"aes_key": aes_key_b64,
"encrypt_type": 1,
},
"file_name": file_name,
"len": str(file_size),
}
})
return self._send_items(to, context_token, items)
def send_video_item(self, to: str, context_token: str,
encrypt_query_param: str, aes_key_b64: str,
ciphertext_size: int, text: str = "") -> dict:
items = []
if text:
items.append({"type": 1, "text_item": {"text": text}})
items.append({
"type": 5,
"video_item": {
"media": {
"encrypt_query_param": encrypt_query_param,
"aes_key": aes_key_b64,
"encrypt_type": 1,
},
"video_size": ciphertext_size,
}
})
return self._send_items(to, context_token, items)
def _send_items(self, to: str, context_token: str, items: list) -> dict:
return self._post("ilink/bot/sendmessage", {
"msg": {
"from_user_id": "",
"to_user_id": to,
"client_id": uuid.uuid4().hex[:16],
"message_type": 2,
"message_state": 2,
"item_list": items,
"context_token": context_token,
}
})
# ── getUploadUrl ───────────────────────────────────────────────────
def get_upload_url(self, filekey: str, media_type: int, to_user_id: str,
rawsize: int, rawfilemd5: str, filesize: int,
aeskey: str) -> dict:
return self._post("ilink/bot/getuploadurl", {
"filekey": filekey,
"media_type": media_type,
"to_user_id": to_user_id,
"rawsize": rawsize,
"rawfilemd5": rawfilemd5,
"filesize": filesize,
"aeskey": aeskey,
"no_need_thumb": True,
})
# ── getConfig / sendTyping ─────────────────────────────────────────
def get_config(self, user_id: str, context_token: str = "") -> dict:
return self._post("ilink/bot/getconfig", {
"ilink_user_id": user_id,
"context_token": context_token,
}, timeout=10)
def send_typing(self, user_id: str, typing_ticket: str, status: int = 1) -> dict:
return self._post("ilink/bot/sendtyping", {
"ilink_user_id": user_id,
"typing_ticket": typing_ticket,
"status": status,
}, timeout=10)
# ── QR Login ───────────────────────────────────────────────────────
def fetch_qr_code(self) -> dict:
url = _ensure_trailing_slash(self.base_url) + f"ilink/bot/get_bot_qrcode?bot_type={BOT_TYPE}"
resp = requests.get(url, timeout=15)
resp.raise_for_status()
return resp.json()
def poll_qr_status(self, qrcode: str, timeout: int = QR_POLL_TIMEOUT) -> dict:
url = (_ensure_trailing_slash(self.base_url) +
f"ilink/bot/get_qrcode_status?qrcode={requests.utils.quote(qrcode)}")
headers = {"iLink-App-ClientVersion": "1"}
try:
resp = requests.get(url, headers=headers, timeout=timeout)
resp.raise_for_status()
return resp.json()
except requests.exceptions.Timeout:
return {"status": "wait"}
# ── AES-128-ECB helpers ─────────────────────────────────────────────
def _aes_ecb_encrypt(data: bytes, key: bytes) -> bytes:
from Crypto.Cipher import AES
pad_len = 16 - (len(data) % 16)
padded = data + bytes([pad_len] * pad_len)
cipher = AES.new(key, AES.MODE_ECB)
return cipher.encrypt(padded)
def _aes_ecb_decrypt(data: bytes, key: bytes) -> bytes:
from Crypto.Cipher import AES
cipher = AES.new(key, AES.MODE_ECB)
decrypted = cipher.decrypt(data)
pad_len = decrypted[-1]
if pad_len > 16:
return decrypted
return decrypted[:-pad_len]
def _file_md5(file_path: str) -> str:
h = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
h.update(chunk)
return h.hexdigest()
def _md5_bytes(data: bytes) -> str:
return hashlib.md5(data).hexdigest()
def _aes_ecb_padded_size(plaintext_size: int) -> int:
"""PKCS7 padded size for AES-128-ECB."""
return ((plaintext_size + 1 + 15) // 16) * 16
UPLOAD_MAX_RETRIES = 3
def upload_media_to_cdn(api: WeixinApi, file_path: str, to_user_id: str,
media_type: int) -> dict:
"""
Upload a local file to the Weixin CDN (matching official plugin protocol).
Args:
api: WeixinApi instance
file_path: local file path
to_user_id: target user id
media_type: 1=IMAGE, 2=VIDEO, 3=FILE
Returns:
dict with keys: encrypt_query_param, aes_key_b64, ciphertext_size, raw_size
"""
aes_key = os.urandom(16)
aes_key_hex = aes_key.hex()
filekey = uuid.uuid4().hex
with open(file_path, "rb") as f:
raw_data = f.read()
raw_size = len(raw_data)
raw_md5 = _md5_bytes(raw_data)
cipher_size = _aes_ecb_padded_size(raw_size)
encrypted = _aes_ecb_encrypt(raw_data, aes_key)
from urllib.parse import quote
download_param = None
last_error = None
for attempt in range(1, UPLOAD_MAX_RETRIES + 1):
try:
if attempt > 1:
filekey = uuid.uuid4().hex
resp = api.get_upload_url(
filekey=filekey,
media_type=media_type,
to_user_id=to_user_id,
rawsize=raw_size,
rawfilemd5=raw_md5,
filesize=cipher_size,
aeskey=aes_key_hex,
)
upload_param = resp.get("upload_param", "")
if not upload_param:
raise RuntimeError(f"[Weixin] getUploadUrl returned no upload_param: {resp}")
cdn_url = (f"{api.cdn_base_url}/upload"
f"?encrypted_query_param={quote(upload_param)}"
f"&filekey={quote(filekey)}")
cdn_resp = requests.post(cdn_url, data=encrypted, headers={
"Content-Type": "application/octet-stream",
"Content-Length": str(len(encrypted)),
}, timeout=120)
if 400 <= cdn_resp.status_code < 500:
err_msg = cdn_resp.headers.get("x-error-message", cdn_resp.text[:200])
raise RuntimeError(f"CDN client error {cdn_resp.status_code}: {err_msg}")
cdn_resp.raise_for_status()
download_param = cdn_resp.headers.get("x-encrypted-param", "")
if not download_param:
raise RuntimeError("CDN response missing x-encrypted-param header")
logger.debug(f"[Weixin] CDN upload success attempt={attempt} filekey={filekey}")
break
except Exception as e:
last_error = e
if "client error" in str(e):
raise
if attempt < UPLOAD_MAX_RETRIES:
backoff = 2 ** attempt
logger.warning(f"[Weixin] CDN upload attempt {attempt} failed, retrying in {backoff}s: {e}")
time.sleep(backoff)
else:
logger.error(f"[Weixin] CDN upload failed after {UPLOAD_MAX_RETRIES} attempts: {e}")
if not download_param:
raise last_error or RuntimeError("CDN upload failed")
aes_key_b64 = base64.b64encode(aes_key_hex.encode("utf-8")).decode("utf-8")
return {
"encrypt_query_param": download_param,
"aes_key_b64": aes_key_b64,
"ciphertext_size": cipher_size,
"raw_size": raw_size,
}
def download_media_from_cdn(cdn_base_url: str, encrypt_query_param: str,
aes_key: str, save_path: str) -> str:
"""
Download and decrypt a media file from Weixin CDN.
Args:
cdn_base_url: CDN base URL
encrypt_query_param: encrypted query parameter from message
aes_key: hex or base64 encoded AES key
save_path: path to save decrypted file
Returns:
save_path on success
"""
from urllib.parse import quote
url = f"{cdn_base_url}/download?encrypted_query_param={quote(encrypt_query_param)}"
resp = requests.get(url, timeout=60)
resp.raise_for_status()
# Determine key format:
# 1) 32-char hex string → 16 raw bytes
# 2) base64 string → decode → if 32 bytes, treat as hex-encoded → 16 raw bytes
# 3) base64 string → decode → 16 raw bytes directly
try:
key_bytes = bytes.fromhex(aes_key)
if len(key_bytes) != 16:
raise ValueError()
except (ValueError, TypeError):
decoded = base64.b64decode(aes_key)
if len(decoded) == 32:
try:
key_bytes = bytes.fromhex(decoded.decode("ascii"))
except (ValueError, UnicodeDecodeError):
raise ValueError(f"Invalid AES key: 32 bytes but not valid hex")
elif len(decoded) == 16:
key_bytes = decoded
else:
raise ValueError(f"Invalid AES key length after base64 decode: {len(decoded)}")
decrypted = _aes_ecb_decrypt(resp.content, key_bytes)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, "wb") as f:
f.write(decrypted)
return save_path

View File

@@ -0,0 +1,629 @@
"""
Weixin channel implementation.
Uses HTTP long-poll (getUpdates) to receive messages and sendMessage to reply.
Login via QR code scan through the ilink bot API.
"""
import json
import os
import threading
import time
import uuid
import requests
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel, check_prefix
from channel.weixin.weixin_api import (
WeixinApi, upload_media_to_cdn,
DEFAULT_BASE_URL, CDN_BASE_URL,
)
from channel.weixin.weixin_message import WeixinMessage
from common.expired_dict import ExpiredDict
from common.log import logger
from common.singleton import singleton
from config import conf
MAX_CONSECUTIVE_FAILURES = 3
BACKOFF_DELAY = 30
RETRY_DELAY = 2
SESSION_EXPIRED_ERRCODE = -14
TEXT_CHUNK_LIMIT = 4000
QR_LOGIN_TIMEOUT_S = 480
QR_MAX_REFRESHES = 10
def _load_credentials(cred_path: str) -> dict:
"""Load saved credentials from JSON file."""
try:
if os.path.exists(cred_path):
with open(cred_path, "r") as f:
return json.load(f)
except Exception as e:
logger.warning(f"[Weixin] Failed to load credentials: {e}")
return {}
def _save_credentials(cred_path: str, data: dict):
"""Save credentials to JSON file."""
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
with open(cred_path, "w") as f:
json.dump(data, f, indent=2)
try:
os.chmod(cred_path, 0o600)
except Exception:
pass
@singleton
class WeixinChannel(ChatChannel):
LOGIN_STATUS_IDLE = "idle"
LOGIN_STATUS_WAITING = "waiting_scan"
LOGIN_STATUS_SCANNED = "scanned"
LOGIN_STATUS_OK = "logged_in"
def __init__(self):
super().__init__()
self.api = None
self._stop_event = threading.Event()
self._poll_thread = None
self._context_tokens = {} # user_id -> context_token
self._received_msgs = ExpiredDict(60 * 60 * 7.1)
self._get_updates_buf = ""
self._credentials_path = ""
self.login_status = self.LOGIN_STATUS_IDLE
self._current_qr_url = ""
conf()["single_chat_prefix"] = [""]
# ── Lifecycle ──────────────────────────────────────────────────────
def startup(self):
self._stop_event.clear()
base_url = conf().get("weixin_base_url", DEFAULT_BASE_URL)
cdn_base_url = conf().get("weixin_cdn_base_url", CDN_BASE_URL)
token = conf().get("weixin_token", "")
self._credentials_path = os.path.expanduser(
conf().get("weixin_credentials_path", "~/.weixin_cow_credentials.json")
)
if not token:
creds = _load_credentials(self._credentials_path)
token = creds.get("token", "")
if creds.get("base_url"):
base_url = creds["base_url"]
if not token:
token, base_url = self._login_with_retry(base_url)
if not token:
return
self.api = WeixinApi(base_url=base_url, token=token, cdn_base_url=cdn_base_url)
self.login_status = self.LOGIN_STATUS_OK
logger.info(f"[Weixin] 微信通道已启动,凭证保存在 {self._credentials_path}"
f"如需重新扫码登录请删除该文件后重启")
self.report_startup_success()
self._poll_loop()
def _login_with_retry(self, base_url: str) -> tuple:
"""Attempt QR login, then wait for stop if failed.
Returns (token, base_url) on success, or ("", "") if stopped."""
logger.info("[Weixin] No token found, starting QR login...")
self.login_status = self.LOGIN_STATUS_WAITING
login_result = self._qr_login(base_url)
if login_result:
return login_result["token"], login_result.get("base_url", base_url)
self.login_status = self.LOGIN_STATUS_IDLE
if not self._stop_event.is_set():
logger.info("[Weixin] QR login timed out, waiting for stop or reconnect...")
print(" 二维码登录超时,请通过控制台重新接入\n")
self._stop_event.wait()
logger.info("[Weixin] Login cancelled by stop event")
return "", ""
def stop(self):
logger.info("[Weixin] stop() called")
self._stop_event.set()
def _relogin(self) -> bool:
"""Re-login after session expiry. Returns True on success."""
base_url = self.api.base_url if self.api else DEFAULT_BASE_URL
if os.path.exists(self._credentials_path):
try:
os.remove(self._credentials_path)
except Exception:
pass
self.login_status = self.LOGIN_STATUS_WAITING
result = self._qr_login(base_url)
if not result:
self.login_status = self.LOGIN_STATUS_IDLE
return False
self.api = WeixinApi(
base_url=result.get("base_url", base_url),
token=result["token"],
cdn_base_url=self.api.cdn_base_url if self.api else CDN_BASE_URL,
)
self.login_status = self.LOGIN_STATUS_OK
self._context_tokens.clear()
return True
# ── QR Login ───────────────────────────────────────────────────────
@staticmethod
def _print_qr(qrcode_url: str):
"""Print QR code to terminal for scanning."""
print("\n" + "=" * 60)
print(" 请使用微信扫描二维码登录 (二维码约2分钟后过期)")
print("=" * 60)
try:
import qrcode as qr_lib
qr = qr_lib.QRCode(error_correction=qr_lib.constants.ERROR_CORRECT_L, box_size=1, border=1)
qr.add_data(qrcode_url)
qr.make(fit=True)
qr.print_ascii(invert=True)
except ImportError:
print(f"\n 二维码链接: {qrcode_url}")
print(" (安装 'qrcode' 包可在终端显示二维码)\n")
def _notify_cloud_qrcode(self, qrcode_url: str):
"""Send QR code URL to cloud console when running in cloud mode."""
if not self.cloud_mode:
return
try:
from common import cloud_client
client = getattr(cloud_client, "chat_client", None)
if client and getattr(client, "client_id", None):
client.send_channel_qrcode("weixin", qrcode_url)
except Exception as e:
logger.warning(f"[Weixin] Failed to notify cloud QR code: {e}")
def _notify_cloud_connected(self):
"""Send connected status to cloud console when login succeeds."""
if not self.cloud_mode:
return
try:
from common import cloud_client
client = getattr(cloud_client, "chat_client", None)
if client and getattr(client, "client_id", None):
client.send_channel_status("weixin", "connected")
except Exception as e:
logger.warning(f"[Weixin] Failed to notify cloud connected: {e}")
def _qr_login(self, base_url: str) -> dict:
"""Perform interactive QR code login. Returns dict with token/base_url or empty dict."""
api = WeixinApi(base_url=base_url)
try:
qr_resp = api.fetch_qr_code()
except Exception as e:
logger.error(f"[Weixin] Failed to fetch QR code: {e}")
return {}
qrcode = qr_resp.get("qrcode", "")
qrcode_url = qr_resp.get("qrcode_img_content", "")
if not qrcode:
logger.error("[Weixin] No QR code returned from server")
return {}
self._current_qr_url = qrcode_url
logger.info(f"[Weixin] 微信二维码链接: {qrcode_url}")
self._print_qr(qrcode_url)
self._notify_cloud_qrcode(qrcode_url)
print(" 等待扫码...\n")
scanned_printed = False
refresh_count = 0
deadline = time.time() + QR_LOGIN_TIMEOUT_S
while not self._stop_event.is_set():
if time.time() >= deadline:
logger.warning(f"[Weixin] QR login timed out after {QR_LOGIN_TIMEOUT_S}s")
print(f"\n 二维码登录超时({QR_LOGIN_TIMEOUT_S}s请重启后重试")
break
try:
status_resp = api.poll_qr_status(qrcode)
except Exception as e:
logger.error(f"[Weixin] QR status poll error: {e}")
return {}
status = status_resp.get("status", "wait")
if status == "wait":
pass
elif status == "scaned":
self.login_status = self.LOGIN_STATUS_SCANNED
if not scanned_printed:
print(" 已扫码,请在手机上确认...")
scanned_printed = True
elif status == "expired":
refresh_count += 1
if refresh_count >= QR_MAX_REFRESHES:
logger.warning(f"[Weixin] QR code refreshed {QR_MAX_REFRESHES} times, giving up")
print(f"\n 二维码已刷新 {QR_MAX_REFRESHES} 次仍未扫码,请重启后重试")
break
print(f" 二维码已过期,正在刷新({refresh_count}/{QR_MAX_REFRESHES}...")
try:
qr_resp = api.fetch_qr_code()
qrcode = qr_resp.get("qrcode", "")
qrcode_url = qr_resp.get("qrcode_img_content", "")
scanned_printed = False
self._current_qr_url = qrcode_url
logger.info(f"[Weixin] 微信二维码链接 ({refresh_count}/{QR_MAX_REFRESHES}): {qrcode_url}")
self._print_qr(qrcode_url)
self._notify_cloud_qrcode(qrcode_url)
except Exception as e:
logger.error(f"[Weixin] QR refresh failed: {e}")
return {}
elif status == "confirmed":
bot_token = status_resp.get("bot_token", "")
bot_id = status_resp.get("ilink_bot_id", "")
result_base_url = status_resp.get("baseurl", base_url)
user_id = status_resp.get("ilink_user_id", "")
if not bot_token or not bot_id:
logger.error("[Weixin] Login confirmed but missing token/bot_id")
return {}
self._current_qr_url = ""
print(f"\n ✅ 微信登录成功bot_id={bot_id}")
logger.info(f"[Weixin] Login confirmed: bot_id={bot_id}")
self._notify_cloud_connected()
creds = {
"token": bot_token,
"base_url": result_base_url,
"bot_id": bot_id,
"user_id": user_id,
}
_save_credentials(self._credentials_path, creds)
logger.info(f"[Weixin] Credentials saved to {self._credentials_path}")
return {"token": bot_token, "base_url": result_base_url}
self._stop_event.wait(1)
self._current_qr_url = ""
if self._stop_event.is_set():
logger.info("[Weixin] QR login cancelled by stop event")
return {}
# ── Long-poll loop ─────────────────────────────────────────────────
def _poll_loop(self):
"""Main long-poll loop: getUpdates -> parse -> produce."""
logger.info("[Weixin] Starting long-poll loop")
consecutive_failures = 0
while not self._stop_event.is_set():
try:
resp = self.api.get_updates(self._get_updates_buf)
ret = resp.get("ret", 0)
errcode = resp.get("errcode", 0)
is_error = (ret != 0) or (errcode != 0)
if is_error:
if errcode == SESSION_EXPIRED_ERRCODE or ret == SESSION_EXPIRED_ERRCODE:
logger.error("[Weixin] Session expired (errcode -14), starting re-login...")
if self._relogin():
logger.info("[Weixin] Re-login successful, resuming long-poll")
self._get_updates_buf = ""
consecutive_failures = 0
continue
else:
logger.error("[Weixin] Re-login failed, will retry in 5 minutes")
self._stop_event.wait(300)
continue
consecutive_failures += 1
errmsg = resp.get("errmsg", "")
logger.error(f"[Weixin] getUpdates error: ret={ret} errcode={errcode} "
f"errmsg={errmsg} ({consecutive_failures}/{MAX_CONSECUTIVE_FAILURES})")
if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
consecutive_failures = 0
self._stop_event.wait(BACKOFF_DELAY)
else:
self._stop_event.wait(RETRY_DELAY)
continue
consecutive_failures = 0
# Update sync cursor
new_buf = resp.get("get_updates_buf", "")
if new_buf:
self._get_updates_buf = new_buf
# Process messages
msgs = resp.get("msgs", [])
for raw_msg in msgs:
try:
self._process_message(raw_msg)
except Exception as e:
logger.error(f"[Weixin] Failed to process message: {e}", exc_info=True)
except Exception as e:
if self._stop_event.is_set():
break
consecutive_failures += 1
logger.error(f"[Weixin] getUpdates exception: {e} "
f"({consecutive_failures}/{MAX_CONSECUTIVE_FAILURES})")
if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
consecutive_failures = 0
self._stop_event.wait(BACKOFF_DELAY)
else:
self._stop_event.wait(RETRY_DELAY)
logger.info("[Weixin] Long-poll loop ended")
def _process_message(self, raw_msg: dict):
"""Parse a single inbound message and produce to the handling queue."""
msg_type = raw_msg.get("message_type", 0)
if msg_type != 1: # Only process USER messages (type=1)
return
msg_id = str(raw_msg.get("message_id", raw_msg.get("seq", "")))
if self._received_msgs.get(msg_id):
return
self._received_msgs[msg_id] = True
from_user = raw_msg.get("from_user_id", "")
context_token = raw_msg.get("context_token", "")
if context_token and from_user:
self._context_tokens[from_user] = context_token
cdn_base_url = self.api.cdn_base_url if self.api else CDN_BASE_URL
try:
wx_msg = WeixinMessage(raw_msg, cdn_base_url=cdn_base_url)
except Exception as e:
logger.error(f"[Weixin] Failed to parse WeixinMessage: {e}", exc_info=True)
return
logger.info(f"[Weixin] Received: from={from_user} ctype={wx_msg.ctype} "
f"content={str(wx_msg.content)[:50]}")
# File cache logic
from channel.file_cache import get_file_cache
file_cache = get_file_cache()
session_id = from_user
if wx_msg.ctype == ContextType.IMAGE:
if hasattr(wx_msg, "image_path") and wx_msg.image_path:
file_cache.add(session_id, wx_msg.image_path, file_type="image")
logger.info(f"[Weixin] Image cached for session {session_id}")
return
if wx_msg.ctype == ContextType.FILE:
wx_msg.prepare()
file_cache.add(session_id, wx_msg.content, file_type="file")
logger.info(f"[Weixin] File cached for session {session_id}: {wx_msg.content}")
return
if wx_msg.ctype == ContextType.TEXT:
cached_files = file_cache.get(session_id)
if cached_files:
refs = []
for fi in cached_files:
ftype, fpath = fi["type"], fi["path"]
if ftype == "image":
refs.append(f"[图片: {fpath}]")
elif ftype == "video":
refs.append(f"[视频: {fpath}]")
else:
refs.append(f"[文件: {fpath}]")
wx_msg.content = wx_msg.content + "\n" + "\n".join(refs)
file_cache.clear(session_id)
context = self._compose_context(
wx_msg.ctype,
wx_msg.content,
isgroup=False,
msg=wx_msg,
no_need_at=True,
)
if context:
self.produce(context)
# ── _compose_context ───────────────────────────────────────────────
def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content)
context.kwargs = kwargs
if "channel_type" not in context:
context["channel_type"] = self.channel_type
if "origin_ctype" not in context:
context["origin_ctype"] = ctype
cmsg = context["msg"]
context["session_id"] = cmsg.from_user_id
context["receiver"] = cmsg.other_user_id
if ctype == ContextType.TEXT:
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
if img_match_prefix:
content = content.replace(img_match_prefix, "", 1)
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT
context.content = content.strip()
return context
# ── Send reply ─────────────────────────────────────────────────────
def send(self, reply: Reply, context: Context):
receiver = context.get("receiver", "")
msg = context.get("msg")
context_token = self._get_context_token(receiver, msg)
if not context_token:
logger.error(f"[Weixin] No context_token for receiver={receiver}, cannot send")
return
if reply.type == ReplyType.TEXT:
self._send_text(reply.content, receiver, context_token)
elif reply.type in (ReplyType.IMAGE_URL, ReplyType.IMAGE):
self._send_image(reply.content, receiver, context_token)
elif reply.type == ReplyType.FILE:
self._send_file(reply.content, receiver, context_token)
elif reply.type in (ReplyType.VIDEO, ReplyType.VIDEO_URL):
self._send_video(reply.content, receiver, context_token)
else:
logger.warning(f"[Weixin] Unsupported reply type: {reply.type}, fallback to text")
self._send_text(str(reply.content), receiver, context_token)
def _get_context_token(self, receiver: str, msg=None) -> str:
"""Get the context_token for a receiver, required for all sends."""
if msg and hasattr(msg, "context_token") and msg.context_token:
return msg.context_token
return self._context_tokens.get(receiver, "")
def _send_text(self, text: str, receiver: str, context_token: str):
if len(text) <= TEXT_CHUNK_LIMIT:
try:
self.api.send_text(receiver, text, context_token)
logger.debug(f"[Weixin] Text sent to {receiver}, len={len(text)}")
except Exception as e:
logger.error(f"[Weixin] Failed to send text: {e}")
return
chunks = self._split_text(text, TEXT_CHUNK_LIMIT)
for i, chunk in enumerate(chunks):
try:
self.api.send_text(receiver, chunk, context_token)
logger.debug(f"[Weixin] Text chunk {i+1}/{len(chunks)} sent to {receiver}, len={len(chunk)}")
except Exception as e:
logger.error(f"[Weixin] Failed to send text chunk {i+1}/{len(chunks)}: {e}")
break
if i < len(chunks) - 1:
time.sleep(0.5)
@staticmethod
def _split_text(text: str, limit: int) -> list:
"""Split text into chunks, preferring to break at paragraph or line boundaries."""
if len(text) <= limit:
return [text]
chunks = []
while text:
if len(text) <= limit:
chunks.append(text)
break
cut = text.rfind("\n\n", 0, limit)
if cut <= 0:
cut = text.rfind("\n", 0, limit)
if cut <= 0:
cut = limit
chunks.append(text[:cut])
text = text[cut:].lstrip("\n")
return chunks
def _send_image(self, img_path_or_url: str, receiver: str, context_token: str):
local_path = self._resolve_media_path(img_path_or_url)
if not local_path:
self._send_text("[Image send failed: file not found]", receiver, context_token)
return
try:
result = upload_media_to_cdn(self.api, local_path, receiver, media_type=1)
self.api.send_image_item(
to=receiver,
context_token=context_token,
encrypt_query_param=result["encrypt_query_param"],
aes_key_b64=result["aes_key_b64"],
ciphertext_size=result["ciphertext_size"],
)
logger.info(f"[Weixin] Image sent to {receiver}")
except Exception as e:
logger.error(f"[Weixin] Image send failed: {e}")
self._send_text("[Image send failed]", receiver, context_token)
def _send_file(self, file_path_or_url: str, receiver: str, context_token: str):
local_path = self._resolve_media_path(file_path_or_url)
if not local_path:
self._send_text("[File send failed: file not found]", receiver, context_token)
return
try:
result = upload_media_to_cdn(self.api, local_path, receiver, media_type=3)
self.api.send_file_item(
to=receiver,
context_token=context_token,
encrypt_query_param=result["encrypt_query_param"],
aes_key_b64=result["aes_key_b64"],
file_name=os.path.basename(local_path),
file_size=result["raw_size"],
)
logger.info(f"[Weixin] File sent to {receiver}")
except Exception as e:
logger.error(f"[Weixin] File send failed: {e}")
self._send_text("[File send failed]", receiver, context_token)
def _send_video(self, video_path_or_url: str, receiver: str, context_token: str):
local_path = self._resolve_media_path(video_path_or_url)
if not local_path:
self._send_text("[Video send failed: file not found]", receiver, context_token)
return
try:
result = upload_media_to_cdn(self.api, local_path, receiver, media_type=2)
self.api.send_video_item(
to=receiver,
context_token=context_token,
encrypt_query_param=result["encrypt_query_param"],
aes_key_b64=result["aes_key_b64"],
ciphertext_size=result["ciphertext_size"],
)
logger.info(f"[Weixin] Video sent to {receiver}")
except Exception as e:
logger.error(f"[Weixin] Video send failed: {e}")
self._send_text("[Video send failed]", receiver, context_token)
@staticmethod
def _resolve_media_path(path_or_url: str) -> str:
"""Resolve a file path or URL to a local file path. Downloads if needed."""
if not path_or_url:
return ""
local_path = path_or_url
if local_path.startswith("file://"):
local_path = local_path[7:]
if local_path.startswith(("http://", "https://")):
try:
resp = requests.get(local_path, timeout=60)
resp.raise_for_status()
ct = resp.headers.get("Content-Type", "")
ext = ".bin"
if "jpeg" in ct or "jpg" in ct:
ext = ".jpg"
elif "png" in ct:
ext = ".png"
elif "gif" in ct:
ext = ".gif"
elif "webp" in ct:
ext = ".webp"
elif "mp4" in ct:
ext = ".mp4"
elif "pdf" in ct:
ext = ".pdf"
tmp_path = f"/tmp/wx_media_{uuid.uuid4().hex[:8]}{ext}"
with open(tmp_path, "wb") as f:
f.write(resp.content)
return tmp_path
except Exception as e:
logger.error(f"[Weixin] Failed to download media: {e}")
return ""
if os.path.exists(local_path):
return local_path
logger.warning(f"[Weixin] Media file not found: {local_path}")
return ""

View File

@@ -0,0 +1,204 @@
"""
Weixin ChatMessage implementation.
Parses WeixinMessage from the getUpdates API into the unified ChatMessage format.
"""
import os
import uuid
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from channel.weixin.weixin_api import download_media_from_cdn, CDN_BASE_URL
from common.log import logger
from common.utils import expand_path
from config import conf
# MessageItemType constants from the Weixin protocol
ITEM_TEXT = 1
ITEM_IMAGE = 2
ITEM_VOICE = 3
ITEM_FILE = 4
ITEM_VIDEO = 5
def _get_tmp_dir() -> str:
ws_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(ws_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
return tmp_dir
class WeixinMessage(ChatMessage):
"""Message wrapper for Weixin channel."""
def __init__(self, msg: dict, cdn_base_url: str = CDN_BASE_URL):
super().__init__(msg)
self.msg_id = str(msg.get("message_id", msg.get("seq", uuid.uuid4().hex[:8])))
self.create_time = msg.get("create_time_ms", 0)
self.context_token = msg.get("context_token", "")
self.is_group = False # Weixin plugin only supports direct chat
self.is_at = False
from_user_id = msg.get("from_user_id", "")
to_user_id = msg.get("to_user_id", "")
self.from_user_id = from_user_id
self.from_user_nickname = from_user_id
self.to_user_id = to_user_id
self.to_user_nickname = to_user_id
self.other_user_id = from_user_id
self.other_user_nickname = from_user_id
self.actual_user_id = from_user_id
self.actual_user_nickname = from_user_id
item_list = msg.get("item_list", [])
# Parse items: find text and media
text_body = ""
media_item = None
media_type = None
ref_text = ""
for item in item_list:
itype = item.get("type", 0)
if itype == ITEM_TEXT:
text_item = item.get("text_item", {})
text_body = text_item.get("text", "")
ref = item.get("ref_msg")
if ref:
ref_title = ref.get("title", "")
ref_mi = ref.get("message_item", {})
ref_body = ""
if ref_mi.get("type") == ITEM_TEXT:
ref_body = ref_mi.get("text_item", {}).get("text", "")
if ref_title or ref_body:
parts = [p for p in [ref_title, ref_body] if p]
ref_text = f"[引用: {' | '.join(parts)}]\n"
# If ref is a media item, treat it as the media to download
if ref_mi.get("type") in (ITEM_IMAGE, ITEM_VIDEO, ITEM_FILE):
media_item = ref_mi
media_type = ref_mi.get("type")
elif itype == ITEM_VOICE:
voice_item = item.get("voice_item", {})
voice_text = voice_item.get("text", "")
if voice_text:
text_body = voice_text
else:
# Voice without transcription - download the audio
media_item = item
media_type = ITEM_VOICE
elif itype in (ITEM_IMAGE, ITEM_VIDEO, ITEM_FILE):
if not media_item:
media_item = item
media_type = itype
# Determine ctype and content
if media_item and not text_body:
self._setup_media(media_item, media_type, cdn_base_url)
elif media_item and text_body:
# Text + media: download media, attach as file ref in text
self.ctype = ContextType.TEXT
media_path = self._download_media(media_item, media_type, cdn_base_url)
if media_path:
if media_type == ITEM_IMAGE:
text_body += f"\n[图片: {media_path}]"
elif media_type == ITEM_VIDEO:
text_body += f"\n[视频: {media_path}]"
else:
text_body += f"\n[文件: {media_path}]"
self.content = ref_text + text_body
else:
self.ctype = ContextType.TEXT
self.content = ref_text + text_body
def _setup_media(self, item: dict, media_type: int, cdn_base_url: str):
"""Set up message as a media type, with lazy download via _prepare_fn."""
if media_type == ITEM_IMAGE:
self.ctype = ContextType.IMAGE
image_path = self._download_media(item, ITEM_IMAGE, cdn_base_url)
if image_path:
self.content = image_path
self.image_path = image_path
else:
self.ctype = ContextType.TEXT
self.content = "[Image download failed]"
elif media_type == ITEM_VIDEO:
self.ctype = ContextType.FILE
save_path = os.path.join(_get_tmp_dir(), f"wx_{self.msg_id}.mp4")
self.content = save_path
def _download():
path = self._download_media(item, ITEM_VIDEO, cdn_base_url)
if path:
self.content = path
self._prepare_fn = _download
elif media_type == ITEM_FILE:
self.ctype = ContextType.FILE
file_name = item.get("file_item", {}).get("file_name", f"wx_{self.msg_id}")
save_path = os.path.join(_get_tmp_dir(), file_name)
self.content = save_path
def _download():
path = self._download_media(item, ITEM_FILE, cdn_base_url)
if path:
self.content = path
self._prepare_fn = _download
elif media_type == ITEM_VOICE:
self.ctype = ContextType.VOICE
save_path = os.path.join(_get_tmp_dir(), f"wx_{self.msg_id}.silk")
self.content = save_path
def _download():
path = self._download_media(item, ITEM_VOICE, cdn_base_url)
if path:
self.content = path
self._prepare_fn = _download
def _download_media(self, item: dict, media_type: int, cdn_base_url: str) -> str:
"""Download media from CDN, returns local file path or empty string."""
type_key_map = {
ITEM_IMAGE: "image_item",
ITEM_VIDEO: "video_item",
ITEM_FILE: "file_item",
ITEM_VOICE: "voice_item",
}
key = type_key_map.get(media_type, "")
info = item.get(key, {})
media = info.get("media", {})
encrypt_param = media.get("encrypt_query_param", "")
# aes_key can be in image_item.aeskey (hex) or media.aes_key (b64)
aes_key = info.get("aeskey", "") or media.get("aes_key", "")
if not encrypt_param or not aes_key:
logger.warning(f"[Weixin] Missing CDN params for media download (type={media_type})")
return ""
if media_type == ITEM_FILE:
original_name = info.get("file_name", "")
if original_name:
save_path = os.path.join(_get_tmp_dir(), original_name)
else:
save_path = os.path.join(_get_tmp_dir(), f"wx_{self.msg_id}.bin")
else:
ext_map = {ITEM_IMAGE: ".jpg", ITEM_VIDEO: ".mp4", ITEM_VOICE: ".silk"}
ext = ext_map.get(media_type, "")
save_path = os.path.join(_get_tmp_dir(), f"wx_{self.msg_id}{ext}")
try:
download_media_from_cdn(cdn_base_url, encrypt_param, aes_key, save_path)
logger.info(f"[Weixin] Media downloaded: {save_path}")
return save_path
except Exception as e:
logger.error(f"[Weixin] Media download failed: {e}")
return ""

View File

@@ -1,17 +0,0 @@
import os
import time
os.environ['ntwork_LOG'] = "ERROR"
import ntwork
wework = ntwork.WeWork()
def forever():
try:
while True:
time.sleep(0.1)
except KeyboardInterrupt:
ntwork.exit_()
os._exit(0)

View File

@@ -1,326 +0,0 @@
import io
import os
import random
import tempfile
import threading
os.environ['ntwork_LOG'] = "ERROR"
import ntwork
import requests
import uuid
from bridge.context import *
from bridge.reply import *
from channel.chat_channel import ChatChannel
from channel.wework.wework_message import *
from channel.wework.wework_message import WeworkMessage
from common.singleton import singleton
from common.log import logger
from common.time_check import time_checker
from common.utils import compress_imgfile, fsize
from config import conf
from channel.wework.run import wework
from channel.wework import run
from PIL import Image
def get_wxid_by_name(room_members, group_wxid, name):
if group_wxid in room_members:
for member in room_members[group_wxid]['member_list']:
if member['room_nickname'] == name or member['username'] == name:
return member['user_id']
return None # 如果没有找到对应的group_wxid或name则返回None
def download_and_compress_image(url, filename, quality=30):
# 确定保存图片的目录
directory = os.path.join(os.getcwd(), "tmp")
# 如果目录不存在,则创建目录
if not os.path.exists(directory):
os.makedirs(directory)
# 下载图片
pic_res = requests.get(url, stream=True)
image_storage = io.BytesIO()
for block in pic_res.iter_content(1024):
image_storage.write(block)
# 检查图片大小并可能进行压缩
sz = fsize(image_storage)
if sz >= 10 * 1024 * 1024: # 如果图片大于 10 MB
logger.info("[wework] image too large, ready to compress, sz={}".format(sz))
image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
logger.info("[wework] image compressed, sz={}".format(fsize(image_storage)))
# 将内存缓冲区的指针重置到起始位置
image_storage.seek(0)
# 读取并保存图片
image = Image.open(image_storage)
image_path = os.path.join(directory, f"{filename}.png")
image.save(image_path, "png")
return image_path
def download_video(url, filename):
# 确定保存视频的目录
directory = os.path.join(os.getcwd(), "tmp")
# 如果目录不存在,则创建目录
if not os.path.exists(directory):
os.makedirs(directory)
# 下载视频
response = requests.get(url, stream=True)
total_size = 0
video_path = os.path.join(directory, f"{filename}.mp4")
with open(video_path, 'wb') as f:
for block in response.iter_content(1024):
total_size += len(block)
# 如果视频的总大小超过30MB (30 * 1024 * 1024 bytes),则停止下载并返回
if total_size > 30 * 1024 * 1024:
logger.info("[WX] Video is larger than 30MB, skipping...")
return None
f.write(block)
return video_path
def create_message(wework_instance, message, is_group):
logger.debug(f"正在为{'群聊' if is_group else '单聊'}创建 WeworkMessage")
cmsg = WeworkMessage(message, wework=wework_instance, is_group=is_group)
logger.debug(f"cmsg:{cmsg}")
return cmsg
def handle_message(cmsg, is_group):
logger.debug(f"准备用 WeworkChannel 处理{'群聊' if is_group else '单聊'}消息")
if is_group:
WeworkChannel().handle_group(cmsg)
else:
WeworkChannel().handle_single(cmsg)
logger.debug(f"已用 WeworkChannel 处理完{'群聊' if is_group else '单聊'}消息")
def _check(func):
def wrapper(self, cmsg: ChatMessage):
msgId = cmsg.msg_id
create_time = cmsg.create_time # 消息时间戳
if create_time is None:
return func(self, cmsg)
if int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
logger.debug("[WX]history message {} skipped".format(msgId))
return
return func(self, cmsg)
return wrapper
@wework.msg_register(
[ntwork.MT_RECV_TEXT_MSG, ntwork.MT_RECV_IMAGE_MSG, 11072, ntwork.MT_RECV_LINK_CARD_MSG,ntwork.MT_RECV_FILE_MSG, ntwork.MT_RECV_VOICE_MSG])
def all_msg_handler(wework_instance: ntwork.WeWork, message):
logger.debug(f"收到消息: {message}")
if 'data' in message:
# 首先查找conversation_id如果没有找到则查找room_conversation_id
conversation_id = message['data'].get('conversation_id', message['data'].get('room_conversation_id'))
if conversation_id is not None:
is_group = "R:" in conversation_id
try:
cmsg = create_message(wework_instance=wework_instance, message=message, is_group=is_group)
except NotImplementedError as e:
logger.error(f"[WX]{message.get('MsgId', 'unknown')} 跳过: {e}")
return None
delay = random.randint(1, 2)
timer = threading.Timer(delay, handle_message, args=(cmsg, is_group))
timer.start()
else:
logger.debug("消息数据中无 conversation_id")
return None
return None
def accept_friend_with_retries(wework_instance, user_id, corp_id):
result = wework_instance.accept_friend(user_id, corp_id)
logger.debug(f'result:{result}')
# @wework.msg_register(ntwork.MT_RECV_FRIEND_MSG)
# def friend(wework_instance: ntwork.WeWork, message):
# data = message["data"]
# user_id = data["user_id"]
# corp_id = data["corp_id"]
# logger.info(f"接收到好友请求,消息内容:{data}")
# delay = random.randint(1, 180)
# threading.Timer(delay, accept_friend_with_retries, args=(wework_instance, user_id, corp_id)).start()
#
# return None
def get_with_retry(get_func, max_retries=5, delay=5):
retries = 0
result = None
while retries < max_retries:
result = get_func()
if result:
break
logger.warning(f"获取数据失败,重试第{retries + 1}次······")
retries += 1
time.sleep(delay) # 等待一段时间后重试
return result
@singleton
class WeworkChannel(ChatChannel):
NOT_SUPPORT_REPLYTYPE = []
def __init__(self):
super().__init__()
def startup(self):
smart = conf().get("wework_smart", True)
wework.open(smart)
logger.info("等待登录······")
wework.wait_login()
login_info = wework.get_login_info()
self.user_id = login_info['user_id']
self.name = login_info['nickname']
logger.info(f"登录信息:>>>user_id:{self.user_id}>>>>>>>>name:{self.name}")
logger.info("静默延迟60s等待客户端刷新数据请勿进行任何操作······")
time.sleep(60)
contacts = get_with_retry(wework.get_external_contacts)
rooms = get_with_retry(wework.get_rooms)
directory = os.path.join(os.getcwd(), "tmp")
if not contacts or not rooms:
logger.error("获取contacts或rooms失败程序退出")
ntwork.exit_()
os.exit(0)
if not os.path.exists(directory):
os.makedirs(directory)
# 将contacts保存到json文件中
with open(os.path.join(directory, 'wework_contacts.json'), 'w', encoding='utf-8') as f:
json.dump(contacts, f, ensure_ascii=False, indent=4)
with open(os.path.join(directory, 'wework_rooms.json'), 'w', encoding='utf-8') as f:
json.dump(rooms, f, ensure_ascii=False, indent=4)
# 创建一个空字典来保存结果
result = {}
# 遍历列表中的每个字典
for room in rooms['room_list']:
# 获取聊天室ID
room_wxid = room['conversation_id']
# 获取聊天室成员
room_members = wework.get_room_members(room_wxid)
# 将聊天室成员保存到结果字典中
result[room_wxid] = room_members
# 将结果保存到json文件中
with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=4)
logger.info("wework程序初始化完成········")
run.forever()
@time_checker
@_check
def handle_single(self, cmsg: ChatMessage):
if cmsg.from_user_id == cmsg.to_user_id:
# ignore self reply
return
if cmsg.ctype == ContextType.VOICE:
if not conf().get("speech_recognition"):
return
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.IMAGE:
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.PATPAT:
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.TEXT:
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
else:
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
if context:
self.produce(context)
@time_checker
@_check
def handle_group(self, cmsg: ChatMessage):
if cmsg.ctype == ContextType.VOICE:
if not conf().get("speech_recognition"):
return
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.IMAGE:
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
logger.debug("[WX]receive note msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.TEXT:
pass
else:
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
if context:
self.produce(context)
# 统一的发送函数每个Channel自行实现根据reply的type字段发送不同类型的消息
def send(self, reply: Reply, context: Context):
logger.debug(f"context: {context}")
receiver = context["receiver"]
actual_user_id = context["msg"].actual_user_id
if reply.type == ReplyType.TEXT or reply.type == ReplyType.TEXT_:
match = re.search(r"^@(.*?)\n", reply.content)
logger.debug(f"match: {match}")
if match:
new_content = re.sub(r"^@(.*?)\n", "\n", reply.content)
at_list = [actual_user_id]
logger.debug(f"new_content: {new_content}")
wework.send_room_at_msg(receiver, new_content, at_list)
else:
wework.send_text(receiver, reply.content)
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
wework.send_text(receiver, reply.content)
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
# Read data from image_storage
data = image_storage.read()
# Create a temporary file
with tempfile.NamedTemporaryFile(delete=False) as temp:
temp_path = temp.name
temp.write(data)
# Send the image
wework.send_image(receiver, temp_path)
logger.info("[WX] sendImage, receiver={}".format(receiver))
# Remove the temporary file
os.remove(temp_path)
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
filename = str(uuid.uuid4())
# 调用你的函数,下载图片并保存为本地文件
image_path = download_and_compress_image(img_url, filename)
wework.send_image(receiver, file_path=image_path)
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
elif reply.type == ReplyType.VIDEO_URL:
video_url = reply.content
filename = str(uuid.uuid4())
video_path = download_video(video_url, filename)
if video_path is None:
# 如果视频太大,下载可能会被跳过,此时 video_path 将为 None
wework.send_text(receiver, "抱歉,视频太大了!!!")
else:
wework.send_video(receiver, video_path)
logger.info("[WX] sendVideo, receiver={}".format(receiver))
elif reply.type == ReplyType.VOICE:
current_dir = os.getcwd()
voice_file = reply.content.split("/")[-1]
reply.content = os.path.join(current_dir, "tmp", voice_file)
wework.send_file(receiver, reply.content)
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))

View File

@@ -1,227 +0,0 @@
import datetime
import json
import os
import re
import time
import pilk
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.log import logger
from ntwork.const import send_type
def get_with_retry(get_func, max_retries=5, delay=5):
retries = 0
result = None
while retries < max_retries:
result = get_func()
if result:
break
logger.warning(f"获取数据失败,重试第{retries + 1}次······")
retries += 1
time.sleep(delay) # 等待一段时间后重试
return result
def get_room_info(wework, conversation_id):
logger.debug(f"传入的 conversation_id: {conversation_id}")
rooms = wework.get_rooms()
if not rooms or 'room_list' not in rooms:
logger.error(f"获取群聊信息失败: {rooms}")
return None
time.sleep(1)
logger.debug(f"获取到的群聊信息: {rooms}")
for room in rooms['room_list']:
if room['conversation_id'] == conversation_id:
return room
return None
def cdn_download(wework, message, file_name):
data = message["data"]
aes_key = data["cdn"]["aes_key"]
file_size = data["cdn"]["size"]
# 获取当前工作目录,然后与文件名拼接得到保存路径
current_dir = os.getcwd()
save_path = os.path.join(current_dir, "tmp", file_name)
# 下载保存图片到本地
if "url" in data["cdn"].keys() and "auth_key" in data["cdn"].keys():
url = data["cdn"]["url"]
auth_key = data["cdn"]["auth_key"]
# result = wework.wx_cdn_download(url, auth_key, aes_key, file_size, save_path) # ntwork库本身接口有问题缺失了aes_key这个参数
"""
下载wx类型的cdn文件以https开头
"""
data = {
'url': url,
'auth_key': auth_key,
'aes_key': aes_key,
'size': file_size,
'save_path': save_path
}
result = wework._WeWork__send_sync(send_type.MT_WXCDN_DOWNLOAD_MSG, data) # 直接用wx_cdn_download的接口内部实现来调用
elif "file_id" in data["cdn"].keys():
if message["type"] == 11042:
file_type = 2
elif message["type"] == 11045:
file_type = 5
file_id = data["cdn"]["file_id"]
result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
else:
logger.error(f"something is wrong, data: {data}")
return
# 输出下载结果
logger.debug(f"result: {result}")
def c2c_download_and_convert(wework, message, file_name):
data = message["data"]
aes_key = data["cdn"]["aes_key"]
file_size = data["cdn"]["size"]
file_type = 5
file_id = data["cdn"]["file_id"]
current_dir = os.getcwd()
save_path = os.path.join(current_dir, "tmp", file_name)
result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
logger.debug(result)
# 在下载完SILK文件之后立即将其转换为WAV文件
base_name, _ = os.path.splitext(save_path)
wav_file = base_name + ".wav"
pilk.silk_to_wav(save_path, wav_file, rate=24000)
# 删除SILK文件
try:
os.remove(save_path)
except Exception as e:
pass
class WeworkMessage(ChatMessage):
def __init__(self, wework_msg, wework, is_group=False):
try:
super().__init__(wework_msg)
self.msg_id = wework_msg['data'].get('conversation_id', wework_msg['data'].get('room_conversation_id'))
# 使用.get()防止 'send_time' 键不存在时抛出错误
self.create_time = wework_msg['data'].get("send_time")
self.is_group = is_group
self.wework = wework
if wework_msg["type"] == 11041: # 文本消息类型
if any(substring in wework_msg['data']['content'] for substring in ("该消息类型暂不能展示", "不支持的消息类型")):
return
self.ctype = ContextType.TEXT
self.content = wework_msg['data']['content']
elif wework_msg["type"] == 11044: # 语音消息类型,需要缓存文件
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".silk"
base_name, _ = os.path.splitext(file_name)
file_name_2 = base_name + ".wav"
current_dir = os.getcwd()
self.ctype = ContextType.VOICE
self.content = os.path.join(current_dir, "tmp", file_name_2)
self._prepare_fn = lambda: c2c_download_and_convert(wework, wework_msg, file_name)
elif wework_msg["type"] == 11042: # 图片消息类型,需要下载文件
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".jpg"
current_dir = os.getcwd()
self.ctype = ContextType.IMAGE
self.content = os.path.join(current_dir, "tmp", file_name)
self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name)
elif wework_msg["type"] == 11045: # 文件消息
print("文件消息")
print(wework_msg)
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
file_name = file_name + wework_msg['data']['cdn']['file_name']
current_dir = os.getcwd()
self.ctype = ContextType.FILE
self.content = os.path.join(current_dir, "tmp", file_name)
self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name)
elif wework_msg["type"] == 11047: # 链接消息
self.ctype = ContextType.SHARING
self.content = wework_msg['data']['url']
elif wework_msg["type"] == 11072: # 新成员入群通知
self.ctype = ContextType.JOIN_GROUP
member_list = wework_msg['data']['member_list']
self.actual_user_nickname = member_list[0]['name']
self.actual_user_id = member_list[0]['user_id']
self.content = f"{self.actual_user_nickname}加入了群聊!"
directory = os.path.join(os.getcwd(), "tmp")
rooms = get_with_retry(wework.get_rooms)
if not rooms:
logger.error("更新群信息失败···")
else:
result = {}
for room in rooms['room_list']:
# 获取聊天室ID
room_wxid = room['conversation_id']
# 获取聊天室成员
room_members = wework.get_room_members(room_wxid)
# 将聊天室成员保存到结果字典中
result[room_wxid] = room_members
with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=4)
logger.info("有新成员加入,已自动更新群成员列表缓存!")
else:
raise NotImplementedError(
"Unsupported message type: Type:{} MsgType:{}".format(wework_msg["type"], wework_msg["MsgType"]))
data = wework_msg['data']
login_info = self.wework.get_login_info()
logger.debug(f"login_info: {login_info}")
nickname = f"{login_info['username']}({login_info['nickname']})" if login_info['nickname'] else login_info['username']
user_id = login_info['user_id']
sender_id = data.get('sender')
conversation_id = data.get('conversation_id')
sender_name = data.get("sender_name")
self.from_user_id = user_id if sender_id == user_id else conversation_id
self.from_user_nickname = nickname if sender_id == user_id else sender_name
self.to_user_id = user_id
self.to_user_nickname = nickname
self.other_user_nickname = sender_name
self.other_user_id = conversation_id
if self.is_group:
conversation_id = data.get('conversation_id') or data.get('room_conversation_id')
self.other_user_id = conversation_id
if conversation_id:
room_info = get_room_info(wework=wework, conversation_id=conversation_id)
self.other_user_nickname = room_info.get('nickname', None) if room_info else None
self.from_user_nickname = room_info.get('nickname', None) if room_info else None
at_list = data.get('at_list', [])
tmp_list = []
for at in at_list:
tmp_list.append(at['nickname'])
at_list = tmp_list
logger.debug(f"at_list: {at_list}")
logger.debug(f"nickname: {nickname}")
self.is_at = False
if nickname in at_list or login_info['nickname'] in at_list or login_info['username'] in at_list:
self.is_at = True
self.at_list = at_list
# 检查消息内容是否包含@用户名。处理复制粘贴的消息,这类消息可能不会触发@通知,但内容中可能包含 "@用户名"。
content = data.get('content', '')
name = nickname
pattern = f"@{re.escape(name)}(\u2005|\u0020)"
if re.search(pattern, content):
logger.debug(f"Wechaty message {self.msg_id} includes at")
self.is_at = True
if not self.actual_user_id:
self.actual_user_id = data.get("sender")
self.actual_user_nickname = sender_name if self.ctype != ContextType.JOIN_GROUP else self.actual_user_nickname
else:
logger.error("群聊消息中没有找到 conversation_id 或 room_conversation_id")
logger.debug(f"WeworkMessage has been successfully instantiated with message id: {self.msg_id}")
except Exception as e:
logger.error(f"在 WeworkMessage 的初始化过程中出现错误:{e}")
raise e

1
cli/VERSION Normal file
View File

@@ -0,0 +1 @@
2.0.4

13
cli/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""CowAgent CLI - Manage your CowAgent from the command line."""
import os as _os
def _read_version():
version_file = _os.path.join(_os.path.dirname(_os.path.abspath(__file__)), "VERSION")
try:
with open(version_file, "r") as f:
return f.read().strip()
except FileNotFoundError:
return "0.0.0"
__version__ = _read_version()

4
cli/__main__.py Normal file
View File

@@ -0,0 +1,4 @@
"""Allow running as: python -m cli"""
from cli.cli import main
main()

73
cli/cli.py Normal file
View File

@@ -0,0 +1,73 @@
"""CowAgent CLI entry point."""
import click
from cli import __version__
from cli.commands.skill import skill
from cli.commands.process import start, stop, restart, update, status, logs
from cli.commands.context import context
HELP_TEXT = """Usage: cow COMMAND [ARGS]...
CowAgent CLI - Manage your CowAgent instance.
Commands:
help Show this message.
version Show the version.
start Start CowAgent.
stop Stop CowAgent.
restart Restart CowAgent.
update Update CowAgent and restart.
status Show CowAgent running status.
logs View CowAgent logs.
skill Manage CowAgent skills.
Tip: You can also send /help, /skill list, etc. in agent chat."""
class CowCLI(click.Group):
def format_help(self, ctx, formatter):
formatter.write(HELP_TEXT.strip())
formatter.write("\n")
def parse_args(self, ctx, args):
if args and args[0] == 'help':
click.echo(HELP_TEXT.strip())
ctx.exit(0)
return super().parse_args(ctx, args)
@click.group(cls=CowCLI, invoke_without_command=True, context_settings=dict(help_option_names=[]))
@click.pass_context
def main(ctx):
"""CowAgent CLI - Manage your CowAgent instance."""
if ctx.invoked_subcommand is None:
click.echo(HELP_TEXT.strip())
@main.command()
def version():
"""Show the version."""
click.echo(f"cow {__version__}")
@main.command(name='help')
@click.pass_context
def help_cmd(ctx):
"""Show this message."""
click.echo(HELP_TEXT.strip())
main.add_command(skill)
main.add_command(start)
main.add_command(stop)
main.add_command(restart)
main.add_command(update)
main.add_command(status)
main.add_command(logs)
main.add_command(context)
if __name__ == '__main__':
main()

0
cli/commands/__init__.py Normal file
View File

29
cli/commands/context.py Normal file
View File

@@ -0,0 +1,29 @@
"""cow context - Context management commands."""
import click
CHAT_HINT = (
"Context commands operate on the running agent's memory.\n"
"Please send the command in a chat conversation instead:\n\n"
" /context - View current context info\n"
" /context clear - Clear conversation context"
)
@click.group(invoke_without_command=True)
@click.pass_context
def context(ctx):
"""View or manage conversation context.
Context commands need access to the running agent's memory.
Use them in chat conversations: /context or /context clear
"""
if ctx.invoked_subcommand is None:
click.echo(f"\n {CHAT_HINT}\n")
@context.command()
def clear():
"""Clear conversation context (messages history)."""
click.echo(f"\n {CHAT_HINT}\n")

281
cli/commands/process.py Normal file
View File

@@ -0,0 +1,281 @@
"""cow start/stop/restart/status/logs - Process management commands."""
import os
import sys
import subprocess
import time
from typing import Optional
import click
from cli.utils import get_project_root
_IS_WIN = sys.platform == "win32"
def _get_pid_file():
return os.path.join(get_project_root(), ".cow.pid")
def _get_log_file():
return os.path.join(get_project_root(), "nohup.out")
def _is_pid_alive(pid: int) -> bool:
"""Check whether a process is still running (cross-platform)."""
if _IS_WIN:
try:
out = subprocess.check_output(
["tasklist", "/FI", f"PID eq {pid}", "/NH"],
stderr=subprocess.DEVNULL,
)
return str(pid) in out.decode(errors="ignore")
except Exception:
return False
else:
try:
os.kill(pid, 0)
return True
except (ProcessLookupError, PermissionError):
return False
def _kill_pid(pid: int, force: bool = False):
"""Terminate a process by PID (cross-platform)."""
if _IS_WIN:
flag = "/F" if force else ""
cmd = ["taskkill"]
if force:
cmd.append("/F")
cmd.extend(["/PID", str(pid)])
subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
else:
import signal
sig = signal.SIGKILL if force else signal.SIGTERM
os.kill(pid, sig)
def _read_pid() -> Optional[int]:
pid_file = _get_pid_file()
if not os.path.exists(pid_file):
return None
try:
with open(pid_file, "r") as f:
pid = int(f.read().strip())
if _is_pid_alive(pid):
return pid
os.remove(pid_file)
return None
except (ValueError, OSError):
try:
os.remove(pid_file)
except OSError:
pass
return None
def _write_pid(pid: int):
with open(_get_pid_file(), "w") as f:
f.write(str(pid))
def _remove_pid():
pid_file = _get_pid_file()
if os.path.exists(pid_file):
os.remove(pid_file)
@click.command()
@click.option("--foreground", "-f", is_flag=True, help="Run in foreground (don't daemonize)")
@click.option("--no-logs", is_flag=True, help="Don't tail logs after starting")
def start(foreground, no_logs):
"""Start CowAgent."""
pid = _read_pid()
if pid:
click.echo(f"CowAgent is already running (PID: {pid}).")
return
root = get_project_root()
app_py = os.path.join(root, "app.py")
if not os.path.exists(app_py):
click.echo("Error: app.py not found in project root.", err=True)
sys.exit(1)
python = sys.executable
if foreground:
click.echo("Starting CowAgent in foreground...")
if _IS_WIN:
sys.exit(subprocess.call([python, app_py], cwd=root))
else:
os.execv(python, [python, app_py])
else:
log_file = _get_log_file()
click.echo("Starting CowAgent...")
popen_kwargs = dict(cwd=root)
if _IS_WIN:
CREATE_NO_WINDOW = 0x08000000
popen_kwargs["creationflags"] = (
subprocess.CREATE_NEW_PROCESS_GROUP | CREATE_NO_WINDOW
)
else:
popen_kwargs["start_new_session"] = True
with open(log_file, "a") as log:
proc = subprocess.Popen(
[python, app_py],
stdout=log,
stderr=log,
**popen_kwargs,
)
_write_pid(proc.pid)
click.echo(click.style(f"✓ CowAgent started (PID: {proc.pid})", fg="green"))
click.echo(f" Logs: {log_file}")
if not no_logs:
click.echo(" Press Ctrl+C to stop tailing logs.\n")
_tail_log(log_file)
@click.command()
def stop():
"""Stop CowAgent."""
pid = _read_pid()
if not pid:
click.echo("CowAgent is not running.")
return
click.echo(f"Stopping CowAgent (PID: {pid})...")
try:
_kill_pid(pid)
for _ in range(30):
time.sleep(0.1)
if not _is_pid_alive(pid):
break
else:
_kill_pid(pid, force=True)
except (ProcessLookupError, OSError):
pass
_remove_pid()
click.echo(click.style("✓ CowAgent stopped.", fg="green"))
@click.command()
@click.option("--no-logs", is_flag=True, help="Don't tail logs after restarting")
@click.pass_context
def restart(ctx, no_logs):
"""Restart CowAgent."""
ctx.invoke(stop)
time.sleep(1)
ctx.invoke(start, no_logs=no_logs)
@click.command()
@click.pass_context
def update(ctx):
"""Update CowAgent and restart."""
root = get_project_root()
# 1. Git pull while service is still running
if os.path.isdir(os.path.join(root, ".git")):
click.echo("Pulling latest code...")
ret = subprocess.call(["git", "pull"], cwd=root)
if ret != 0:
click.echo("Error: git pull failed.", err=True)
sys.exit(1)
else:
click.echo("Not a git repository, skipping code update.")
# 2. Stop service
ctx.invoke(stop)
# 3. Install dependencies
python = sys.executable
req_file = os.path.join(root, "requirements.txt")
if os.path.exists(req_file):
click.echo("Installing dependencies...")
subprocess.call(
[python, "-m", "pip", "install", "-r", "requirements.txt", "-q"],
cwd=root,
)
click.echo("Reinstalling cow CLI...")
subprocess.call(
[python, "-m", "pip", "install", "-e", ".", "-q"],
cwd=root,
)
# 4. Start service
click.echo("")
time.sleep(1)
ctx.invoke(start, no_logs=True)
@click.command()
def status():
"""Show CowAgent running status."""
from cli import __version__
from cli.utils import load_config_json
pid = _read_pid()
if pid:
click.echo(click.style(f"● CowAgent is running (PID: {pid})", fg="green"))
else:
click.echo(click.style("● CowAgent is not running", fg="red"))
click.echo(f" 版本: v{__version__}")
cfg = load_config_json()
if cfg:
channel = cfg.get("channel_type", "unknown")
if isinstance(channel, list):
channel = ", ".join(channel)
click.echo(f" 通道: {channel}")
click.echo(f" 模型: {cfg.get('model', 'unknown')}")
mode = "Agent" if cfg.get("agent") else "Chat"
click.echo(f" 模式: {mode}")
@click.command()
@click.option("--follow", "-f", is_flag=True, help="Follow log output")
@click.option("--lines", "-n", default=50, help="Number of lines to show")
def logs(follow, lines):
"""View CowAgent logs."""
log_file = _get_log_file()
if not os.path.exists(log_file):
click.echo("No log file found.")
return
if follow:
_tail_log(log_file, lines)
else:
_print_last_lines(log_file, lines)
def _print_last_lines(file_path: str, n: int = 50):
"""Print the last N lines of a file (cross-platform)."""
try:
with open(file_path, "r", encoding="utf-8", errors="replace") as f:
all_lines = f.readlines()
for line in all_lines[-n:]:
click.echo(line, nl=False)
except Exception as e:
click.echo(f"Error reading log file: {e}", err=True)
def _tail_log(log_file: str, lines: int = 50):
"""Follow log file output. Blocks until Ctrl+C (cross-platform)."""
_print_last_lines(log_file, lines)
try:
with open(log_file, "r", encoding="utf-8", errors="replace") as f:
f.seek(0, 2)
while True:
line = f.readline()
if line:
click.echo(line, nl=False)
else:
time.sleep(0.3)
except KeyboardInterrupt:
pass

799
cli/commands/skill.py Normal file
View File

@@ -0,0 +1,799 @@
"""cow skill - Skill management commands."""
import os
import re
import sys
import json
import hashlib
import shutil
import zipfile
import tempfile
from urllib.parse import urlparse
import click
import requests
from cli.utils import (
get_project_root,
get_skills_dir,
get_builtin_skills_dir,
load_skills_config,
SKILL_HUB_API,
)
_SAFE_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_\-]{0,63}$")
_GITHUB_URL_RE = re.compile(
r"^https?://github\.com/([^/]+)/([^/]+?)(?:\.git)?(?:/(?:tree|blob)/([^/]+)(?:/(.+))?)?/?$"
)
def _parse_github_url(url: str):
"""Parse a full GitHub URL into (owner, repo, branch, subpath).
Returns None if the URL doesn't match.
Supported formats:
https://github.com/owner/repo
https://github.com/owner/repo/tree/branch
https://github.com/owner/repo/tree/branch/path/to/skill
https://github.com/owner/repo/blob/branch/path/to/skill
"""
m = _GITHUB_URL_RE.match(url.strip())
if not m:
return None
owner, repo, branch, subpath = m.groups()
return owner, repo, branch or "main", subpath
def _download_github_dir(owner, repo, branch, subpath, dest_dir):
"""Download a subdirectory from GitHub using the Contents API.
Recursively fetches all files under the given subpath and writes them
to dest_dir. Raises on any network or API error.
"""
api_url = f"https://api.github.com/repos/{owner}/{repo}/contents/{subpath}?ref={branch}"
resp = requests.get(api_url, timeout=30, headers={"Accept": "application/vnd.github.v3+json"})
resp.raise_for_status()
items = resp.json()
if isinstance(items, dict):
items = [items]
for item in items:
rel_path = item["path"]
if subpath:
rel_path = rel_path[len(subpath.strip("/")):].lstrip("/")
local_path = os.path.join(dest_dir, rel_path)
if item["type"] == "file":
os.makedirs(os.path.dirname(local_path), exist_ok=True)
dl_url = item.get("download_url")
if not dl_url:
continue
file_resp = requests.get(dl_url, timeout=30)
file_resp.raise_for_status()
with open(local_path, "wb") as f:
f.write(file_resp.content)
elif item["type"] == "dir":
os.makedirs(local_path, exist_ok=True)
child_subpath = item["path"]
_download_github_dir(owner, repo, branch, child_subpath, dest_dir)
def _register_installed_skill(name: str, source: str = "cowhub"):
"""Register a newly installed skill into skills_config.json.
source values: builtin, cow, github, clawhub, linkai, local, url
"""
skills_dir = get_skills_dir()
config_path = os.path.join(skills_dir, "skills_config.json")
config = {}
if os.path.exists(config_path):
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
except Exception:
config = {}
if name in config:
return
skill_dir = os.path.join(skills_dir, name)
description = _read_skill_description(skill_dir) or ""
config[name] = {
"name": name,
"description": description,
"source": source,
"enabled": True,
"category": "skill",
}
try:
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config, f, indent=4, ensure_ascii=False)
except Exception:
pass
def _parse_skill_frontmatter(content: str) -> dict:
"""Parse YAML frontmatter from SKILL.md content and return a dict with name/description."""
result = {}
match = re.match(r'^---\s*\n(.*?)\n---\s*\n', content, re.DOTALL)
if not match:
return result
for line in match.group(1).split('\n'):
line = line.strip()
for key in ('name', 'description'):
if line.startswith(f'{key}:'):
val = line[len(key) + 1:].strip()
result[key] = val.strip('"').strip("'")
return result
def _read_skill_description(skill_dir: str) -> str:
"""Read the description from a skill's SKILL.md frontmatter."""
skill_md = os.path.join(skill_dir, "SKILL.md")
if not os.path.exists(skill_md):
return ""
try:
with open(skill_md, "r", encoding="utf-8") as f:
content = f.read()
return _parse_skill_frontmatter(content).get("description", "")
except Exception:
return ""
def _install_url(url: str):
"""Install a skill from a direct SKILL.md URL."""
click.echo(f"Downloading SKILL.md from {url} ...")
try:
resp = requests.get(url, timeout=30)
resp.raise_for_status()
except Exception as e:
click.echo(f"Error: Failed to download SKILL.md: {e}", err=True)
sys.exit(1)
content = resp.text
fm = _parse_skill_frontmatter(content)
skill_name = fm.get("name")
if not skill_name:
click.echo("Error: SKILL.md missing 'name' field in frontmatter.", err=True)
sys.exit(1)
skill_name = skill_name.strip()
_validate_skill_name(skill_name)
skills_dir = get_skills_dir()
os.makedirs(skills_dir, exist_ok=True)
skill_dir = os.path.join(skills_dir, skill_name)
if os.path.isdir(skill_dir):
click.echo(f"Skill '{skill_name}' already exists. Overwriting SKILL.md ...")
os.makedirs(skill_dir, exist_ok=True)
with open(os.path.join(skill_dir, "SKILL.md"), "w", encoding="utf-8") as f:
f.write(content)
_register_installed_skill(skill_name, source="url")
_print_install_success(skill_name, "url")
def _print_install_success(name: str, source: str):
"""Print a unified install success message with description and source."""
skills_dir = get_skills_dir()
desc = _read_skill_description(os.path.join(skills_dir, name))
click.echo(click.style(f"{name}", fg="green"))
if desc:
if len(desc) > 60:
desc = desc[:57] + ""
click.echo(f" {desc}")
click.echo(f" 来源: {source}")
def _validate_skill_name(name: str):
"""Reject names that contain path traversal or special characters."""
if not _SAFE_NAME_RE.match(name):
click.echo(
f"Error: Invalid skill name '{name}'. "
"Use only letters, digits, hyphens, and underscores.",
err=True,
)
sys.exit(1)
def _validate_github_spec(spec: str):
"""Reject specs that don't look like owner/repo."""
if not re.match(r"^[a-zA-Z0-9_\-]+/[a-zA-Z0-9_.\-]+$", spec):
click.echo(f"Error: Invalid GitHub spec '{spec}'. Expected format: owner/repo", err=True)
sys.exit(1)
def _safe_extractall(zf: zipfile.ZipFile, dest: str):
"""Extract zip while guarding against Zip Slip (path traversal)."""
dest = os.path.realpath(dest)
for member in zf.infolist():
target = os.path.realpath(os.path.join(dest, member.filename))
if not target.startswith(dest + os.sep) and target != dest:
raise ValueError(f"Unsafe zip entry detected: {member.filename}")
zf.extractall(dest)
def _verify_checksum(content: bytes, expected: str):
"""Verify SHA-256 checksum of downloaded content.
Returns True if checksum matches or no expected value provided.
Exits with error if mismatch.
"""
if not expected:
return True
actual = hashlib.sha256(content).hexdigest()
if actual != expected.lower():
click.echo(
f"Error: Checksum mismatch!\n"
f" Expected: {expected}\n"
f" Actual: {actual}\n"
f"The downloaded package may have been tampered with.",
err=True,
)
sys.exit(1)
return True
@click.group()
def skill():
"""Manage CowAgent skills."""
pass
# ------------------------------------------------------------------
# cow skill list
# ------------------------------------------------------------------
@skill.command("list")
@click.option("--remote", is_flag=True, help="Browse skills on Skill Hub")
@click.option("--page", default=1, type=int, help="Page number for remote listing")
def skill_list(remote, page):
"""List installed skills or browse Skill Hub."""
if remote:
_list_remote(page=page)
else:
_list_local()
def _list_local():
"""List locally installed skills."""
config = load_skills_config()
skills_dir = get_skills_dir()
builtin_dir = get_builtin_skills_dir()
if not config:
# Fallback: scan directories directly
entries = []
for d in [builtin_dir, skills_dir]:
if not os.path.isdir(d):
continue
source = "builtin" if d == builtin_dir else "custom"
for name in sorted(os.listdir(d)):
skill_path = os.path.join(d, name)
if os.path.isdir(skill_path) and not name.startswith("."):
has_skill_md = os.path.exists(os.path.join(skill_path, "SKILL.md"))
if has_skill_md:
entries.append({"name": name, "source": source, "enabled": True, "description": ""})
if not entries:
click.echo("No skills installed.")
return
_print_skill_table(entries)
return
entries = sorted(config.values(), key=lambda x: x.get("name", ""))
if not entries:
click.echo("No skills installed.")
return
_print_skill_table(entries)
def _print_skill_table(entries):
"""Print skills as a formatted table."""
name_w = max(len(e.get("name", "")) for e in entries)
name_w = max(name_w, 4) + 2
desc_w = 40
header = f"{'Name':<{name_w}} {'Status':<10} {'Source':<10} {'Description'}"
click.echo(f"\n Installed skills ({len(entries)})\n")
click.echo(f" {header}")
click.echo(f" {'' * (name_w + 10 + 10 + desc_w)}")
for e in entries:
name = e.get("name", "")
enabled = e.get("enabled", True)
source = e.get("source", "")
desc = e.get("description", "") or ""
if len(desc) > desc_w:
desc = desc[:desc_w - 3] + "..."
status_icon = click.style("✓ on ", fg="green") if enabled else click.style("✗ off", fg="red")
click.echo(f" {name:<{name_w}} {status_icon} {source:<10} {desc}")
click.echo()
_REMOTE_PAGE_SIZE = 10
def _list_remote(page: int = 1):
"""List skills from remote Skill Hub with server-side pagination."""
try:
resp = requests.get(
f"{SKILL_HUB_API}/skills",
params={"page": page, "limit": _REMOTE_PAGE_SIZE},
timeout=10,
)
resp.raise_for_status()
data = resp.json()
except Exception as e:
click.echo(f"Error: Failed to fetch from Skill Hub: {e}", err=True)
sys.exit(1)
skills = data.get("skills", [])
total = data.get("total", len(skills))
if not skills and page == 1:
click.echo("No skills available on Skill Hub.")
return
total_pages = max(1, (total + _REMOTE_PAGE_SIZE - 1) // _REMOTE_PAGE_SIZE)
page = min(page, total_pages)
installed = set(load_skills_config().keys())
name_w = max((len(s.get("name", "")) for s in skills), default=4)
name_w = max(name_w, 4) + 2
click.echo(f"\n Skill Hub ({total} available) — page {page}/{total_pages}\n")
click.echo(f" {'Name':<{name_w}} {'Status':<12} {'Description'}")
click.echo(f" {'' * (name_w + 12 + 50)}")
for s in skills:
name = s.get("name", "")
desc = s.get("description", "") or s.get("display_name", "")
if len(desc) > 50:
desc = desc[:47] + "..."
status = click.style("installed", fg="green") if name in installed else ""
click.echo(f" {name:<{name_w}} {status:<12} {desc}")
click.echo()
nav_parts = []
if page > 1:
nav_parts.append(f"cow skill list --remote --page {page - 1}")
if page < total_pages:
nav_parts.append(f"cow skill list --remote --page {page + 1}")
if nav_parts:
click.echo(f" Navigate: {' | '.join(nav_parts)}")
click.echo(f" Install: cow skill install <name>\n")
# ------------------------------------------------------------------
# cow skill search
# ------------------------------------------------------------------
@skill.command()
@click.argument("query")
def search(query):
"""Search skills on Skill Hub."""
try:
resp = requests.get(f"{SKILL_HUB_API}/skills/search", params={"q": query}, timeout=10)
resp.raise_for_status()
data = resp.json()
except Exception as e:
click.echo(f"Error: Failed to search Skill Hub: {e}", err=True)
sys.exit(1)
skills = data.get("skills", [])
if not skills:
click.echo(f'No skills found for "{query}".')
return
installed = set(load_skills_config().keys())
name_w = max(len(s.get("name", "")) for s in skills)
name_w = max(name_w, 4) + 2
click.echo(f'\n Search results for "{query}" ({len(skills)} found)\n')
click.echo(f" {'Name':<{name_w}} {'Status':<12} {'Description'}")
click.echo(f" {'' * (name_w + 12 + 50)}")
for s in skills:
name = s.get("name", "")
desc = s.get("description", "") or s.get("display_name", "")
if len(desc) > 50:
desc = desc[:47] + "..."
status = click.style("installed", fg="green") if name in installed else ""
click.echo(f" {name:<{name_w}} {status:<12} {desc}")
click.echo(f"\n Install with: cow skill install <name>\n")
# ------------------------------------------------------------------
# cow skill install
# ------------------------------------------------------------------
@skill.command()
@click.argument("name")
def install(name):
"""Install a skill from Skill Hub, GitHub, or a SKILL.md URL.
Examples:
cow skill install pptx
cow skill install github:owner/repo
cow skill install github:owner/repo#path/to/skill
cow skill install https://github.com/owner/repo/tree/main/path/to/skill
cow skill install https://example.com/path/to/SKILL.md
"""
if name.startswith(("http://", "https://")) and name.rstrip("/").endswith("SKILL.md"):
# GitHub SKILL.md → strip filename and install the whole directory
dir_url = re.sub(r'/SKILL\.md/?$', '', name)
gh = _parse_github_url(dir_url)
if gh:
owner, repo, branch, subpath = gh
spec = f"{owner}/{repo}"
skill_name = subpath.rstrip("/").split("/")[-1] if subpath else repo
_install_github(spec, subpath=subpath, skill_name=skill_name, branch=branch)
return
_install_url(name)
return
parsed = _parse_github_url(name)
if parsed:
owner, repo, branch, subpath = parsed
spec = f"{owner}/{repo}"
skill_name = subpath.rstrip("/").split("/")[-1] if subpath else repo
_install_github(spec, subpath=subpath, skill_name=skill_name, branch=branch)
elif name.startswith("github:"):
skill_name = name[7:]
_validate_skill_name(skill_name)
_install_hub(skill_name)
elif name.startswith("clawhub:"):
skill_name = name[8:]
_validate_skill_name(skill_name)
_install_hub(skill_name, provider="clawhub")
else:
_validate_skill_name(name)
_install_hub(name)
def _install_hub(name, provider=None):
"""Install a skill from Skill Hub."""
skills_dir = get_skills_dir()
os.makedirs(skills_dir, exist_ok=True)
click.echo(f"Fetching skill info for '{name}'...")
try:
body = {}
if provider:
body["provider"] = provider
resp = requests.post(
f"{SKILL_HUB_API}/skills/{name}/download",
json=body,
timeout=15,
)
resp.raise_for_status()
except requests.HTTPError as e:
if e.response is not None and e.response.status_code == 404:
click.echo(f"Error: Skill '{name}' not found on Skill Hub.", err=True)
else:
click.echo(f"Error: Failed to fetch skill: {e}", err=True)
sys.exit(1)
except Exception as e:
click.echo(f"Error: Failed to connect to Skill Hub: {e}", err=True)
sys.exit(1)
content_type = resp.headers.get("Content-Type", "")
if "application/json" in content_type:
data = resp.json()
source_type = data.get("source_type")
if source_type == "github":
source_url = data.get("source_url", "")
parsed_url = _parse_github_url(source_url)
if parsed_url:
owner, repo, branch, subpath = parsed_url
click.echo(f"Source: GitHub ({source_url})")
_install_github(f"{owner}/{repo}", subpath=subpath, skill_name=name, branch=branch)
else:
_validate_github_spec(source_url)
click.echo(f"Source: GitHub ({source_url})")
_install_github(source_url, skill_name=name)
return
if source_type == "registry":
download_url = data.get("download_url")
if download_url:
parsed = urlparse(download_url)
if parsed.scheme != "https":
click.echo(f"Error: Refusing to download from non-HTTPS URL.", err=True)
sys.exit(1)
provider = data.get("source_provider", "registry")
expected_checksum = data.get("checksum") or data.get("sha256")
click.echo(f"Source: {provider}")
click.echo("Downloading skill package...")
try:
dl_resp = requests.get(download_url, timeout=60, allow_redirects=True)
dl_resp.raise_for_status()
except Exception as e:
click.echo(f"Error: Failed to download from {provider}: {e}", err=True)
sys.exit(1)
_verify_checksum(dl_resp.content, expected_checksum)
_install_zip_bytes(dl_resp.content, name, skills_dir)
_register_installed_skill(name, source=provider)
_print_install_success(name, provider)
else:
click.echo(f"Error: Unsupported registry provider.", err=True)
sys.exit(1)
return
if "redirect" in data:
source_url = data.get("source_url", "")
parsed_url = _parse_github_url(source_url)
if parsed_url:
owner, repo, branch, subpath = parsed_url
click.echo(f"Source: GitHub ({source_url})")
_install_github(f"{owner}/{repo}", subpath=subpath, skill_name=name, branch=branch)
else:
_validate_github_spec(source_url)
click.echo(f"Source: GitHub ({source_url})")
_install_github(source_url, skill_name=name)
return
elif "application/zip" in content_type:
click.echo("Downloading skill package...")
expected_checksum = resp.headers.get("X-Checksum-Sha256")
_verify_checksum(resp.content, expected_checksum)
_install_zip_bytes(resp.content, name, skills_dir)
_register_installed_skill(name)
_print_install_success(name, "cowhub")
return
click.echo(f"Error: Unexpected response from Skill Hub.", err=True)
sys.exit(1)
def _install_github(spec, subpath=None, skill_name=None, branch="main", source="github"):
"""Install a skill from a GitHub repo.
spec format: owner/repo or owner/repo#path
"""
if "#" in spec and not subpath:
spec, subpath = spec.split("#", 1)
_validate_github_spec(spec)
if not skill_name:
skill_name = subpath.rstrip("/").split("/")[-1] if subpath else spec.split("/")[-1]
_validate_skill_name(skill_name)
skills_dir = get_skills_dir()
os.makedirs(skills_dir, exist_ok=True)
target_dir = os.path.join(skills_dir, skill_name)
owner, repo = spec.split("/", 1)
# For subpath installs, try GitHub Contents API first (avoids downloading entire repo)
if subpath:
click.echo(f"Downloading from GitHub: {spec}/{subpath} (branch: {branch})...")
try:
with tempfile.TemporaryDirectory() as tmp_dir:
api_dest = os.path.join(tmp_dir, skill_name)
os.makedirs(api_dest)
_download_github_dir(owner, repo, branch, subpath.strip("/"), api_dest)
if os.path.exists(target_dir):
shutil.rmtree(target_dir)
shutil.copytree(api_dest, target_dir)
_register_installed_skill(skill_name, source=source)
_print_install_success(skill_name, source)
return
except Exception:
click.echo("Contents API unavailable, falling back to zip download...")
# Fallback: download full repo zip
zip_url = f"https://github.com/{spec}/archive/refs/heads/{branch}.zip"
click.echo(f"Downloading from GitHub: {spec} (branch: {branch})...")
try:
resp = requests.get(zip_url, timeout=60, allow_redirects=True)
resp.raise_for_status()
except Exception as e:
click.echo(f"Error: Failed to download from GitHub: {e}", err=True)
sys.exit(1)
with tempfile.TemporaryDirectory() as tmp_dir:
zip_path = os.path.join(tmp_dir, "repo.zip")
with open(zip_path, "wb") as f:
f.write(resp.content)
extract_dir = os.path.join(tmp_dir, "extracted")
with zipfile.ZipFile(zip_path, "r") as zf:
_safe_extractall(zf, extract_dir)
top_items = [d for d in os.listdir(extract_dir) if not d.startswith(".")]
repo_root = extract_dir
if len(top_items) == 1 and os.path.isdir(os.path.join(extract_dir, top_items[0])):
repo_root = os.path.join(extract_dir, top_items[0])
if subpath:
source_dir = os.path.join(repo_root, subpath.strip("/"))
if not os.path.isdir(source_dir):
click.echo(f"Error: Path '{subpath}' not found in repository.", err=True)
sys.exit(1)
else:
source_dir = repo_root
if os.path.exists(target_dir):
shutil.rmtree(target_dir)
shutil.copytree(source_dir, target_dir)
_register_installed_skill(skill_name, source=source)
_print_install_success(skill_name, source)
def _install_zip_bytes(content, name, skills_dir):
"""Extract a zip archive into the skills directory."""
with tempfile.TemporaryDirectory() as tmp_dir:
zip_path = os.path.join(tmp_dir, "package.zip")
with open(zip_path, "wb") as f:
f.write(content)
extract_dir = os.path.join(tmp_dir, "extracted")
with zipfile.ZipFile(zip_path, "r") as zf:
_safe_extractall(zf, extract_dir)
top_items = [d for d in os.listdir(extract_dir) if not d.startswith(".")]
source = extract_dir
if len(top_items) == 1 and os.path.isdir(os.path.join(extract_dir, top_items[0])):
source = os.path.join(extract_dir, top_items[0])
target = os.path.join(skills_dir, name)
if os.path.exists(target):
shutil.rmtree(target)
shutil.copytree(source, target)
# ------------------------------------------------------------------
# cow skill uninstall
# ------------------------------------------------------------------
@skill.command()
@click.argument("name")
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation")
def uninstall(name, yes):
"""Uninstall a skill."""
_validate_skill_name(name)
skills_dir = get_skills_dir()
skill_dir = os.path.join(skills_dir, name)
if not os.path.exists(skill_dir):
click.echo(f"Error: Skill '{name}' is not installed.", err=True)
sys.exit(1)
if not yes:
click.confirm(f"Uninstall skill '{name}'?", abort=True)
shutil.rmtree(skill_dir)
config_path = os.path.join(skills_dir, "skills_config.json")
if os.path.exists(config_path):
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
config.pop(name, None)
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config, f, indent=4, ensure_ascii=False)
except Exception:
pass
click.echo(click.style(f"✓ Skill '{name}' uninstalled.", fg="green"))
# ------------------------------------------------------------------
# cow skill enable / disable
# ------------------------------------------------------------------
@skill.command()
@click.argument("name")
def enable(name):
"""Enable a skill."""
_set_enabled(name, True)
@skill.command()
@click.argument("name")
def disable(name):
"""Disable a skill."""
_set_enabled(name, False)
def _set_enabled(name, enabled):
_validate_skill_name(name)
skills_dir = get_skills_dir()
config_path = os.path.join(skills_dir, "skills_config.json")
if not os.path.exists(config_path):
click.echo(f"Error: No skills config found.", err=True)
sys.exit(1)
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
except Exception as e:
click.echo(f"Error: Failed to read skills config: {e}", err=True)
sys.exit(1)
if name not in config:
click.echo(f"Error: Skill '{name}' not found in config.", err=True)
sys.exit(1)
config[name]["enabled"] = enabled
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config, f, indent=4, ensure_ascii=False)
state = "enabled" if enabled else "disabled"
icon = "" if enabled else ""
color = "green" if enabled else "yellow"
click.echo(click.style(f"{icon} Skill '{name}' {state}.", fg=color))
# ------------------------------------------------------------------
# cow skill info
# ------------------------------------------------------------------
@skill.command()
@click.argument("name")
def info(name):
"""Show details about an installed skill."""
_validate_skill_name(name)
skills_dir = get_skills_dir()
builtin_dir = get_builtin_skills_dir()
skill_dir = None
source = None
config = load_skills_config()
for d, src in [(skills_dir, "custom"), (builtin_dir, "builtin")]:
candidate = os.path.join(d, name)
if os.path.isdir(candidate):
skill_dir = candidate
source = config.get(name, {}).get("source") or src
break
if not skill_dir:
click.echo(f"Error: Skill '{name}' not found.", err=True)
sys.exit(1)
skill_md = os.path.join(skill_dir, "SKILL.md")
if not os.path.exists(skill_md):
click.echo(f"Skill directory: {skill_dir}")
click.echo("No SKILL.md found.")
return
with open(skill_md, "r", encoding="utf-8") as f:
content = f.read()
config = load_skills_config()
entry = config.get(name, {})
enabled = entry.get("enabled", True)
status_str = click.style("✓ enabled", fg="green") if enabled else click.style("✗ disabled", fg="red")
click.echo(f"\n Skill: {name}")
click.echo(f" Source: {source}")
click.echo(f" Status: {status_str}")
click.echo(f" Path: {skill_dir}")
click.echo(f"\n{'' * 60}")
# Show first ~30 lines of SKILL.md as a preview
lines = content.split("\n")
preview = "\n".join(lines[:30])
click.echo(preview)
if len(lines) > 30:
click.echo(f"\n ... ({len(lines) - 30} more lines, see {skill_md})")
click.echo()

62
cli/utils.py Normal file
View File

@@ -0,0 +1,62 @@
"""Shared utilities for cow CLI."""
import os
import sys
import json
def get_project_root() -> str:
"""Get the CowAgent project root directory."""
# cli/ is directly under the project root
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def get_workspace_dir() -> str:
"""Get the agent workspace directory from config, defaulting to ~/cow."""
config = load_config_json()
workspace = config.get("agent_workspace", "~/cow")
return os.path.expanduser(workspace)
def get_skills_dir() -> str:
"""Get the custom skills directory."""
return os.path.join(get_workspace_dir(), "skills")
def get_builtin_skills_dir() -> str:
"""Get the builtin skills directory."""
return os.path.join(get_project_root(), "skills")
def load_config_json() -> dict:
"""Load config.json from project root."""
config_path = os.path.join(get_project_root(), "config.json")
if not os.path.exists(config_path):
return {}
try:
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception:
return {}
def load_skills_config() -> dict:
"""Load skills_config.json from the custom skills directory."""
path = os.path.join(get_skills_dir(), "skills_config.json")
if not os.path.exists(path):
return {}
try:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception:
return {}
def ensure_sys_path():
"""Add project root to sys.path so we can import agent modules."""
root = get_project_root()
if root not in sys.path:
sys.path.insert(0, root)
SKILL_HUB_API = "https://skills.cowagent.ai/api"

765
common/cloud_client.py Normal file
View File

@@ -0,0 +1,765 @@
"""
Cloud management client for connecting to the LinkAI control console.
Handles remote configuration sync, message push, and skill management
via the LinkAI socket protocol.
NOTE: By default, no cloud-related config is enabled. The application runs
entirely locally without connecting to any remote service. The cloud client
is only activated when BOTH of the following conditions are met:
1. ``use_linkai`` is set to True in config (checked in app.py before
importing this module).
2. ``cloud_deployment_id`` (or env CLOUD_DEPLOYMENT_ID) is non-empty
(checked in app.py and again in the ``start()`` function below).
If either condition is missing, this module is never loaded and the
program continues as a purely local application.
"""
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from linkai import LinkAIClient, PushMsg
from config import conf, pconf, plugin_config, available_setting, write_plugin_config, get_root
from plugins import PluginManager
import threading
import time
import json
import os
chat_client: LinkAIClient
CHANNEL_ACTIONS = {"channel_create", "channel_update", "channel_delete"}
# channelType -> config key mapping for app credentials
CREDENTIAL_MAP = {
"feishu": ("feishu_app_id", "feishu_app_secret"),
"dingtalk": ("dingtalk_client_id", "dingtalk_client_secret"),
"wecom_bot": ("wecom_bot_id", "wecom_bot_secret"),
"qq": ("qq_app_id", "qq_app_secret"),
"wechatmp": ("wechatmp_app_id", "wechatmp_app_secret"),
"wechatmp_service": ("wechatmp_app_id", "wechatmp_app_secret"),
"wechatcom_app": ("wechatcomapp_agent_id", "wechatcomapp_secret"),
}
class CloudClient(LinkAIClient):
def __init__(self, api_key: str, channel, host: str = ""):
super().__init__(api_key, host)
self.channel = channel
self.client_type = channel.channel_type
self.channel_mgr = None
self._skill_service = None
self._memory_service = None
self._chat_service = None
@property
def skill_service(self):
"""Lazy-init SkillService so it is available once SkillManager exists."""
if self._skill_service is None:
try:
from agent.skills.manager import SkillManager
from agent.skills.service import SkillService
from config import conf
from common.utils import expand_path
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
manager = SkillManager(custom_dir=os.path.join(workspace_root, "skills"))
self._skill_service = SkillService(manager)
logger.debug("[CloudClient] SkillService initialised")
except Exception as e:
logger.error(f"[CloudClient] Failed to init SkillService: {e}")
return self._skill_service
@property
def memory_service(self):
"""Lazy-init MemoryService."""
if self._memory_service is None:
try:
from agent.memory.service import MemoryService
from config import conf
from common.utils import expand_path
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
self._memory_service = MemoryService(workspace_root)
logger.debug("[CloudClient] MemoryService initialised")
except Exception as e:
logger.error(f"[CloudClient] Failed to init MemoryService: {e}")
return self._memory_service
@property
def chat_service(self):
"""Lazy-init ChatService (requires AgentBridge via Bridge singleton)."""
if self._chat_service is None:
try:
from agent.chat.service import ChatService
from bridge.bridge import Bridge
agent_bridge = Bridge().get_agent_bridge()
self._chat_service = ChatService(agent_bridge)
logger.debug("[CloudClient] ChatService initialised")
except Exception as e:
logger.error(f"[CloudClient] Failed to init ChatService: {e}")
return self._chat_service
# ------------------------------------------------------------------
# message push callback
# ------------------------------------------------------------------
def on_message(self, push_msg: PushMsg):
session_id = push_msg.session_id
msg_content = push_msg.msg_content
logger.info(f"receive msg push, session_id={session_id}, msg_content={msg_content}")
context = Context()
context.type = ContextType.TEXT
context["receiver"] = session_id
context["isgroup"] = push_msg.is_group
self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)
# ------------------------------------------------------------------
# config callback
# ------------------------------------------------------------------
def on_config(self, config: dict):
if not self.client_id:
return
logger.info(f"[CloudClient] Loading remote config: {config}")
action = config.get("action")
if action in CHANNEL_ACTIONS:
self._dispatch_channel_action(action, config.get("data", {}))
return
if config.get("enabled") != "Y":
return
local_config = conf()
need_restart_channel = False
for key in config.keys():
if key in available_setting and config.get(key) is not None:
local_config[key] = config.get(key)
# Voice settings
reply_voice_mode = config.get("reply_voice_mode")
if reply_voice_mode:
if reply_voice_mode == "voice_reply_voice":
local_config["voice_reply_voice"] = True
local_config["always_reply_voice"] = False
elif reply_voice_mode == "always_reply_voice":
local_config["always_reply_voice"] = True
local_config["voice_reply_voice"] = True
elif reply_voice_mode == "no_reply_voice":
local_config["always_reply_voice"] = False
local_config["voice_reply_voice"] = False
# Model configuration
if config.get("model"):
local_config["model"] = config.get("model")
# Channel configuration (legacy single-channel path)
if config.get("channelType"):
if local_config.get("channel_type") != config.get("channelType"):
local_config["channel_type"] = config.get("channelType")
need_restart_channel = True
# Channel-specific app credentials (legacy single-channel path)
current_channel_type = local_config.get("channel_type", "")
if self._set_channel_credentials(local_config, current_channel_type,
config.get("app_id"), config.get("app_secret")):
need_restart_channel = True
if config.get("admin_password"):
if not pconf("Godcmd"):
write_plugin_config({"Godcmd": {"password": config.get("admin_password"), "admin_users": []}})
else:
pconf("Godcmd")["password"] = config.get("admin_password")
PluginManager().instances["GODCMD"].reload()
if config.get("group_app_map") and pconf("linkai"):
local_group_map = {}
for mapping in config.get("group_app_map"):
local_group_map[mapping.get("group_name")] = mapping.get("app_code")
pconf("linkai")["group_app_map"] = local_group_map
PluginManager().instances["LINKAI"].reload()
if config.get("text_to_image") and config.get("text_to_image") == "midjourney" and pconf("linkai"):
if pconf("linkai")["midjourney"]:
pconf("linkai")["midjourney"]["enabled"] = True
pconf("linkai")["midjourney"]["use_image_create_prefix"] = True
elif config.get("text_to_image") and config.get("text_to_image") in ["dall-e-2", "dall-e-3"]:
if pconf("linkai")["midjourney"]:
pconf("linkai")["midjourney"]["use_image_create_prefix"] = False
self._save_config_to_file(local_config)
if need_restart_channel:
self._restart_channel(local_config.get("channel_type", ""))
# ------------------------------------------------------------------
# channel CRUD operations
# ------------------------------------------------------------------
def _dispatch_channel_action(self, action: str, data: dict):
channel_type = data.get("channelType")
if not channel_type:
logger.warning(f"[CloudClient] Channel action '{action}' missing channelType, data={data}")
return
logger.info(f"[CloudClient] Channel action: {action}, channelType={channel_type}")
if action == "channel_create":
self._handle_channel_create(channel_type, data)
elif action == "channel_update":
self._handle_channel_update(channel_type, data)
elif action == "channel_delete":
self._handle_channel_delete(channel_type, data)
def _handle_channel_create(self, channel_type: str, data: dict):
local_config = conf()
cred_changed = self._set_channel_credentials(
local_config, channel_type, data.get("appId"), data.get("appSecret"))
self._add_channel_type(local_config, channel_type)
self._save_config_to_file(local_config)
if not self.channel_mgr:
return
existing_ch = self.channel_mgr.get_channel(channel_type)
skip_restart = existing_ch and not cred_changed
if skip_restart and channel_type in ("weixin", "wx"):
login_status = getattr(existing_ch, "login_status", "")
if login_status != "logged_in":
skip_restart = False
logger.info(f"[CloudClient] Channel '{channel_type}' not logged in "
f"(status={login_status}), forcing restart")
if skip_restart:
logger.info(f"[CloudClient] Channel '{channel_type}' already running with same config, "
"skip restart, reporting status only")
threading.Thread(
target=self._report_channel_startup, args=(channel_type,), daemon=True
).start()
return
threading.Thread(
target=self._do_add_channel, args=(channel_type,), daemon=True
).start()
def _handle_channel_update(self, channel_type: str, data: dict):
local_config = conf()
enabled = data.get("enabled", "Y")
cred_changed = self._set_channel_credentials(
local_config, channel_type, data.get("appId"), data.get("appSecret"))
if enabled == "N":
self._remove_channel_type(local_config, channel_type)
else:
self._add_channel_type(local_config, channel_type)
self._save_config_to_file(local_config)
if not self.channel_mgr:
return
if enabled == "N":
threading.Thread(
target=self._do_remove_channel, args=(channel_type,), daemon=True
).start()
else:
existing_ch = self.channel_mgr.get_channel(channel_type)
needs_restart = cred_changed or not existing_ch
if not needs_restart and channel_type in ("weixin", "wx"):
login_status = getattr(existing_ch, "login_status", "")
if login_status != "logged_in":
needs_restart = True
logger.info(f"[CloudClient] Channel '{channel_type}' not logged in "
f"(status={login_status}), forcing restart")
if existing_ch and not needs_restart:
logger.info(f"[CloudClient] Channel '{channel_type}' already running with same config, "
"skip restart, reporting status only")
threading.Thread(
target=self._report_channel_startup, args=(channel_type,), daemon=True
).start()
else:
threading.Thread(
target=self._do_restart_channel, args=(self.channel_mgr, channel_type), daemon=True
).start()
def _handle_channel_delete(self, channel_type: str, data: dict):
local_config = conf()
self._clear_channel_credentials(local_config, channel_type)
self._remove_channel_type(local_config, channel_type)
self._save_config_to_file(local_config)
if channel_type in ("weixin", "wx"):
self._remove_weixin_credentials()
if self.channel_mgr:
threading.Thread(
target=self._do_remove_channel, args=(channel_type,), daemon=True
).start()
@staticmethod
def _remove_weixin_credentials():
"""Remove the weixin token credentials file so next connect triggers QR login."""
cred_path = os.path.expanduser(
conf().get("weixin_credentials_path", "~/.weixin_cow_credentials.json")
)
try:
if os.path.exists(cred_path):
os.remove(cred_path)
logger.info(f"[CloudClient] Removed weixin credentials: {cred_path}")
except Exception as e:
logger.warning(f"[CloudClient] Failed to remove weixin credentials: {e}")
# ------------------------------------------------------------------
# channel credentials helpers
# ------------------------------------------------------------------
@staticmethod
def _set_channel_credentials(local_config: dict, channel_type: str,
app_id, app_secret) -> bool:
"""
Write app_id / app_secret into the correct config keys for *channel_type*.
Also syncs the values to environment variables (upper-cased key) so that
skills that rely on env-based checks (e.g. has_env_var) work immediately.
Returns True if any value actually changed.
"""
cred = CREDENTIAL_MAP.get(channel_type)
if not cred:
return False
id_key, secret_key = cred
changed = False
if app_id is not None and local_config.get(id_key) != app_id:
local_config[id_key] = app_id
os.environ[id_key.upper()] = str(app_id)
changed = True
if app_secret is not None and local_config.get(secret_key) != app_secret:
local_config[secret_key] = app_secret
os.environ[secret_key.upper()] = str(app_secret)
changed = True
if changed:
logger.info(f"[CloudClient] Synced {channel_type} credentials to conf and env")
return changed
@staticmethod
def _clear_channel_credentials(local_config: dict, channel_type: str):
cred = CREDENTIAL_MAP.get(channel_type)
if not cred:
return
id_key, secret_key = cred
local_config.pop(id_key, None)
local_config.pop(secret_key, None)
os.environ.pop(id_key.upper(), None)
os.environ.pop(secret_key.upper(), None)
# ------------------------------------------------------------------
# channel_type list helpers
# ------------------------------------------------------------------
@staticmethod
def _parse_channel_types(local_config: dict) -> list:
raw = local_config.get("channel_type", "")
if isinstance(raw, list):
return [ch.strip() for ch in raw if ch.strip()]
if isinstance(raw, str):
return [ch.strip() for ch in raw.split(",") if ch.strip()]
return []
@staticmethod
def _add_channel_type(local_config: dict, channel_type: str):
types = CloudClient._parse_channel_types(local_config)
if channel_type not in types:
types.append(channel_type)
local_config["channel_type"] = ", ".join(types)
@staticmethod
def _remove_channel_type(local_config: dict, channel_type: str):
types = CloudClient._parse_channel_types(local_config)
if channel_type in types:
types.remove(channel_type)
local_config["channel_type"] = ", ".join(types)
# ------------------------------------------------------------------
# channel manager thread helpers
# ------------------------------------------------------------------
def _do_add_channel(self, channel_type: str):
try:
self.channel_mgr.add_channel(channel_type)
logger.info(f"[CloudClient] Channel '{channel_type}' added successfully")
except Exception as e:
logger.error(f"[CloudClient] Failed to add channel '{channel_type}': {e}", exc_info=True)
self.send_channel_status(channel_type, "error", str(e))
return
self._report_channel_startup(channel_type)
def _do_remove_channel(self, channel_type: str):
try:
self.channel_mgr.remove_channel(channel_type)
logger.info(f"[CloudClient] Channel '{channel_type}' removed successfully")
except Exception as e:
logger.error(f"[CloudClient] Failed to remove channel '{channel_type}': {e}")
def send_channel_qrcode(self, channel_type: str, qrcode_url: str):
"""Report QR code URL for a channel that requires scan-to-login."""
if self.client_id:
from linkai.api.client.client import ClientMsgType
msg = self._build_package(ClientMsgType.CHANNEL_STATUS)
msg["data"]["channelType"] = channel_type
msg["data"]["status"] = "qrcode"
msg["data"]["qrcodeUrl"] = qrcode_url
self._send_package(msg)
logger.info(f"[CloudClient] Sent QR code status for '{channel_type}'")
def _report_channel_startup(self, channel_type: str):
"""Wait for channel startup result and report to cloud."""
ch = self.channel_mgr.get_channel(channel_type)
if not ch:
self.send_channel_status(channel_type, "error", "channel instance not found")
return
if channel_type in ("weixin", "wx") and hasattr(ch, "login_status"):
login_status = getattr(ch, "login_status", "")
if login_status in ("waiting_scan", "scanned", "idle"):
logger.info(f"[CloudClient] Channel '{channel_type}' is waiting for QR login, "
"skip reporting connected")
return
success, error = ch.wait_startup(timeout=3)
if success:
logger.info(f"[CloudClient] Channel '{channel_type}' connected, reporting status")
self.send_channel_status(channel_type, "connected")
else:
logger.warning(f"[CloudClient] Channel '{channel_type}' startup failed: {error}")
self.send_channel_status(channel_type, "error", error)
# ------------------------------------------------------------------
# skill callback
# ------------------------------------------------------------------
def on_skill(self, data: dict) -> dict:
"""
Handle SKILL messages from the cloud console.
Delegates to SkillService.dispatch for the actual operations.
:param data: message data with 'action', 'clientId', 'payload'
:return: response dict
"""
action = data.get("action", "")
payload = data.get("payload")
logger.info(f"[CloudClient] on_skill: action={action}")
svc = self.skill_service
if svc is None:
return {"action": action, "code": 500, "message": "SkillService not available", "payload": None}
return svc.dispatch(action, payload)
# ------------------------------------------------------------------
# memory callback
# ------------------------------------------------------------------
def on_memory(self, data: dict) -> dict:
"""
Handle MEMORY messages from the cloud console.
Delegates to MemoryService.dispatch for the actual operations.
:param data: message data with 'action', 'clientId', 'payload'
:return: response dict
"""
action = data.get("action", "")
payload = data.get("payload")
logger.info(f"[CloudClient] on_memory: action={action}")
svc = self.memory_service
if svc is None:
return {"action": action, "code": 500, "message": "MemoryService not available", "payload": None}
return svc.dispatch(action, payload)
# ------------------------------------------------------------------
# chat callback
# ------------------------------------------------------------------
def on_chat(self, data: dict, send_chunk_fn):
"""
Handle CHAT messages from the cloud console.
Runs the agent in streaming mode and sends chunks back via send_chunk_fn.
:param data: message data with 'action' and 'payload' (query, session_id)
:param send_chunk_fn: callable(chunk_data: dict) to send one streaming chunk
"""
payload = data.get("payload", {})
query = payload.get("query", "")
session_id = payload.get("session_id", "cloud_console")
channel_type = payload.get("channel_type", "")
if not session_id.startswith("session_"):
session_id = f"session_{session_id}"
logger.info(f"[CloudClient] on_chat: session={session_id}, channel={channel_type}, query={query[:80]}")
svc = self.chat_service
if svc is None:
raise RuntimeError("ChatService not available")
svc.run(query=query, session_id=session_id, channel_type=channel_type, send_chunk_fn=send_chunk_fn)
# ------------------------------------------------------------------
# history callback
# ------------------------------------------------------------------
def on_history(self, data: dict) -> dict:
"""
Handle HISTORY messages from the cloud console.
Returns paginated conversation history for a session.
:param data: message data with 'action' and 'payload' (session_id, page, page_size)
:return: response dict
"""
action = data.get("action", "query")
payload = data.get("payload", {})
logger.info(f"[CloudClient] on_history: action={action}")
if action == "query":
return self._query_history(payload)
return {"action": action, "code": 404, "message": f"unknown action: {action}", "payload": None}
def _query_history(self, payload: dict) -> dict:
"""Query paginated conversation history using ConversationStore."""
session_id = payload.get("session_id", "")
page = int(payload.get("page", 1))
page_size = int(payload.get("page_size", 20))
if not session_id:
return {
"action": "query",
"payload": {"status": "error", "message": "session_id required"},
}
# Web channel stores sessions with a "session_" prefix
if not session_id.startswith("session_"):
session_id = f"session_{session_id}"
logger.info(f"[CloudClient] history query: session={session_id}, page={page}, page_size={page_size}")
try:
from agent.memory.conversation_store import get_conversation_store
store = get_conversation_store()
result = store.load_history_page(
session_id=session_id,
page=page,
page_size=page_size,
)
return {
"action": "query",
"payload": {"status": "success", **result},
}
except Exception as e:
logger.error(f"[CloudClient] History query error: {e}")
return {
"action": "query",
"payload": {"status": "error", "message": str(e)},
}
# ------------------------------------------------------------------
# channel restart helpers
# ------------------------------------------------------------------
def _restart_channel(self, new_channel_type: str):
"""
Restart the channel via ChannelManager when channel type changes.
"""
if self.channel_mgr:
logger.info(f"[CloudClient] Restarting channel to '{new_channel_type}'...")
threading.Thread(target=self._do_restart_channel, args=(self.channel_mgr, new_channel_type), daemon=True).start()
else:
logger.warning("[CloudClient] ChannelManager not available, please restart the application manually")
def _do_restart_channel(self, mgr, new_channel_type: str):
"""
Perform the channel restart in a separate thread to avoid blocking the config callback.
"""
try:
mgr.restart(new_channel_type)
if mgr.channel:
self.channel = mgr.channel
self.client_type = mgr.channel.channel_type
logger.info(f"[CloudClient] Channel reference updated to '{new_channel_type}'")
except Exception as e:
logger.error(f"[CloudClient] Channel restart failed: {e}")
self.send_channel_status(new_channel_type, "error", str(e))
return
self._report_channel_startup(new_channel_type)
# ------------------------------------------------------------------
# config persistence
# ------------------------------------------------------------------
def _save_config_to_file(self, local_config: dict):
"""
Save configuration to config.json file.
"""
try:
config_path = os.path.join(get_root(), "config.json")
if not os.path.exists(config_path):
logger.warning(f"[CloudClient] config.json not found at {config_path}, skip saving")
return
with open(config_path, "r", encoding="utf-8") as f:
file_config = json.load(f)
file_config.update(dict(local_config))
with open(config_path, "w", encoding="utf-8") as f:
json.dump(file_config, f, indent=4, ensure_ascii=False)
logger.info("[CloudClient] Configuration saved to config.json successfully")
except Exception as e:
logger.error(f"[CloudClient] Failed to save configuration to config.json: {e}")
def get_root_domain(host: str = "") -> str:
"""Extract root domain from a hostname.
If *host* is empty, reads CLOUD_HOST env var / cloud_host config.
"""
if not host:
host = os.environ.get("CLOUD_HOST") or conf().get("cloud_host", "")
if not host:
return ""
host = host.strip().rstrip("/")
if "://" in host:
host = host.split("://", 1)[1]
host = host.split("/", 1)[0].split(":")[0]
parts = host.split(".")
if len(parts) >= 2:
return ".".join(parts[-2:])
return host
def get_deployment_id() -> str:
"""Return cloud deployment id from env var or config."""
return os.environ.get("CLOUD_DEPLOYMENT_ID") or conf().get("cloud_deployment_id", "")
def get_website_base_url() -> str:
"""Return the public URL prefix that maps to the workspace websites/ dir.
Returns empty string when cloud deployment is not configured.
"""
deployment_id = get_deployment_id()
if not deployment_id:
return ""
websites_domain = os.environ.get("CLOUD_WEBSITES_DOMAIN") or conf().get("cloud_websites_domain", "")
if websites_domain:
websites_domain = websites_domain.strip().rstrip("/")
return f"https://{websites_domain}/{deployment_id}"
domain = get_root_domain()
if not domain:
return ""
return f"https://app.{domain}/{deployment_id}"
def build_website_prompt(workspace_dir: str) -> list:
"""Build system prompt lines for cloud website/file sharing rules.
Returns an empty list when cloud deployment is not configured,
so callers can safely do ``lines.extend(build_website_prompt(...))``.
"""
base_url = get_website_base_url()
if not base_url:
return []
return [
"**文件分享与网页生成规则** (非常重要 — 当前为云部署模式):",
"",
f"云端已为工作空间的 `websites/` 目录配置好公网路由映射,访问地址前缀为: `{base_url}`",
"",
"1. **网页/网站**: 编写网页、H5页面等前端代码时**必须**将文件放到 `websites/` 目录中",
f" - 例如: `websites/index.html` → `{base_url}/index.html`",
f" - 例如: `websites/my-app/index.html` → `{base_url}/my-app/index.html`",
"",
"2. **生成文件分享** (PPT、PDF、图片、音视频等): 当你为用户生成了需要下载或查看的文件时,**可以**将文件保存到 `websites/` 目录中",
f" - 例如: 生成的PPT保存到 `websites/files/report.pptx` → 下载链接为 `{base_url}/files/report.pptx`",
" - 你仍然可以同时使用 `send` 工具发送文件在飞书、钉钉等IM渠道中有效但**必须同时在回复文本中提供下载链接**作为兜底,因为部分渠道(如网页端)无法通过 send 接收本地文件",
"",
"3. **必须发送链接**: 无论是网页还是文件,生成后**必须将完整的访问/下载链接直接写在回复文本中发送给用户**",
"",
"4. **文件名和路径尽量使用英文/拼音/数字等**,不要使用中文,避免链接无法访问",
"",
"5. 建议为每个独立项目在 `websites/` 下创建子目录,保持结构清晰",
"",
]
def start(channel, channel_mgr=None):
if not get_deployment_id():
return
global chat_client
chat_client = CloudClient(api_key=conf().get("linkai_api_key"), host=conf().get("cloud_host", ""), channel=channel)
chat_client.channel_mgr = channel_mgr
chat_client.config = _build_config()
chat_client.start()
time.sleep(1.5)
if chat_client.client_id:
logger.info("[CloudClient] Console: https://link-ai.tech/console/clients")
if channel_mgr:
channel_mgr.cloud_mode = True
threading.Thread(target=_report_existing_channels, args=(chat_client, channel_mgr), daemon=True).start()
def _report_existing_channels(client: CloudClient, mgr):
"""Report status for all channels that were started before cloud client connected."""
try:
for name, ch in list(mgr._channels.items()):
if name == "web":
continue
ch.cloud_mode = True
client._report_channel_startup(name)
except Exception as e:
logger.warning(f"[CloudClient] Failed to report existing channel status: {e}")
def _build_config():
local_conf = conf()
config = {
"linkai_app_code": local_conf.get("linkai_app_code"),
"single_chat_prefix": local_conf.get("single_chat_prefix"),
"single_chat_reply_prefix": local_conf.get("single_chat_reply_prefix"),
"single_chat_reply_suffix": local_conf.get("single_chat_reply_suffix"),
"group_chat_prefix": local_conf.get("group_chat_prefix"),
"group_chat_reply_prefix": local_conf.get("group_chat_reply_prefix"),
"group_chat_reply_suffix": local_conf.get("group_chat_reply_suffix"),
"group_name_white_list": local_conf.get("group_name_white_list"),
"nick_name_black_list": local_conf.get("nick_name_black_list"),
"speech_recognition": "Y" if local_conf.get("speech_recognition") else "N",
"text_to_image": local_conf.get("text_to_image"),
"image_create_prefix": local_conf.get("image_create_prefix"),
"model": local_conf.get("model"),
"agent_max_context_turns": local_conf.get("agent_max_context_turns"),
"agent_max_context_tokens": local_conf.get("agent_max_context_tokens"),
"agent_max_steps": local_conf.get("agent_max_steps"),
"channelType": local_conf.get("channel_type"),
}
if local_conf.get("always_reply_voice"):
config["reply_voice_mode"] = "always_reply_voice"
elif local_conf.get("voice_reply_voice"):
config["reply_voice_mode"] = "voice_reply_voice"
if pconf("linkai"):
config["group_app_map"] = pconf("linkai").get("group_app_map")
if plugin_config.get("Godcmd"):
config["admin_password"] = plugin_config.get("Godcmd").get("password")
# Add channel-specific app credentials
current_channel_type = local_conf.get("channel_type", "")
if current_channel_type == "feishu":
config["app_id"] = local_conf.get("feishu_app_id")
config["app_secret"] = local_conf.get("feishu_app_secret")
elif current_channel_type == "dingtalk":
config["app_id"] = local_conf.get("dingtalk_client_id")
config["app_secret"] = local_conf.get("dingtalk_client_secret")
elif current_channel_type in ("wechatmp", "wechatmp_service"):
config["app_id"] = local_conf.get("wechatmp_app_id")
config["app_secret"] = local_conf.get("wechatmp_app_secret")
elif current_channel_type == "wecom_bot":
config["app_id"] = local_conf.get("wecom_bot_id")
config["app_secret"] = local_conf.get("wecom_bot_secret")
elif current_channel_type == "qq":
config["app_id"] = local_conf.get("qq_app_id")
config["app_secret"] = local_conf.get("qq_app_secret")
elif current_channel_type == "wechatcom_app":
config["app_id"] = local_conf.get("wechatcomapp_agent_id")
config["app_secret"] = local_conf.get("wechatcomapp_secret")
return config

View File

@@ -1,77 +1,107 @@
# bot_type
# 厂商类型
OPEN_AI = "openAI"
CHATGPT = "chatGPT"
BAIDU = "baidu" # 百度文心一言模型
OPENAI = "openai"
CHATGPT = "chatGPT" # legacy alias for OPENAI, kept for backward compatibility
BAIDU = "baidu"
XUNFEI = "xunfei"
CHATGPTONAZURE = "chatGPTOnAzure"
LINKAI = "linkai"
CLAUDEAI = "claude" # 使用cookie的历史模型
CLAUDEAPI= "claudeAPI" # 通过Claude api调用模型
QWEN = "qwen" # 旧版通义模型
QWEN_DASHSCOPE = "dashscope" # 通义新版sdk和api key
GEMINI = "gemini" # gemini-1.0-pro
ZHIPU_AI = "glm-4"
CLAUDEAPI= "claudeAPI"
QWEN = "qwen" # 旧版千问接入
QWEN_DASHSCOPE = "dashscope" # 新版千问接入(百炼)
GEMINI = "gemini"
ZHIPU_AI = "zhipu"
MOONSHOT = "moonshot"
MiniMax = "minimax"
DEEPSEEK = "deepseek"
MODELSCOPE = "modelscope"
# model
# 模型列表
# Claude (Anthropic)
CLAUDE3 = "claude-3-opus-20240229"
GPT35 = "gpt-3.5-turbo"
GPT35_0125 = "gpt-3.5-turbo-0125"
GPT35_1106 = "gpt-3.5-turbo-1106"
GPT_4o = "gpt-4o"
GPT_4O_0806 = "gpt-4o-2024-08-06"
GPT4_TURBO = "gpt-4-turbo"
GPT4_TURBO_PREVIEW = "gpt-4-turbo-preview"
GPT4_TURBO_04_09 = "gpt-4-turbo-2024-04-09"
GPT4_TURBO_01_25 = "gpt-4-0125-preview"
GPT4_TURBO_11_06 = "gpt-4-1106-preview"
GPT4_VISION_PREVIEW = "gpt-4-vision-preview"
GPT4 = "gpt-4"
GPT_4o_MINI = "gpt-4o-mini"
GPT4_32k = "gpt-4-32k"
GPT4_06_13 = "gpt-4-0613"
GPT4_32k_06_13 = "gpt-4-32k-0613"
GPT_41 = "gpt-4.1"
GPT_41_MINI = "gpt-4.1-mini"
GPT_41_NANO = "gpt-4.1-nano"
GPT_5 = "gpt-5"
GPT_5_MINI = "gpt-5-mini"
GPT_5_NANO = "gpt-5-nano"
O1 = "o1-preview"
O1_MINI = "o1-mini"
WHISPER_1 = "whisper-1"
TTS_1 = "tts-1"
TTS_1_HD = "tts-1-hd"
WEN_XIN = "wenxin"
WEN_XIN_4 = "wenxin-4"
QWEN_TURBO = "qwen-turbo"
QWEN_PLUS = "qwen-plus"
QWEN_MAX = "qwen-max"
LINKAI_35 = "linkai-3.5"
LINKAI_4_TURBO = "linkai-4-turbo"
LINKAI_4o = "linkai-4o"
CLAUDE_3_OPUS = "claude-3-opus-latest"
CLAUDE_3_OPUS_0229 = "claude-3-opus-20240229"
CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
CLAUDE_35_SONNET = "claude-3-5-sonnet-latest" # 带 latest 标签的模型名称,会不断更新指向最新发布的模型
CLAUDE_35_SONNET_1022 = "claude-3-5-sonnet-20241022" # 带具体日期的模型名称,会固定为该日期发布的模型
CLAUDE_35_SONNET_0620 = "claude-3-5-sonnet-20240620"
CLAUDE_4_OPUS = "claude-opus-4-0"
CLAUDE_4_6_OPUS = "claude-opus-4-6" # Claude Opus 4.6 - Agent推荐模型
CLAUDE_4_SONNET = "claude-sonnet-4-0" # Claude Sonnet 4.0
CLAUDE_4_5_SONNET = "claude-sonnet-4-5" # Claude Sonnet 4.5 - Agent推荐模型
CLAUDE_4_6_SONNET = "claude-sonnet-4-6" # Claude Sonnet 4.6 - Agent推荐模型
# Gemini (Google)
GEMINI_PRO = "gemini-1.0-pro"
GEMINI_15_flash = "gemini-1.5-flash"
GEMINI_15_PRO = "gemini-1.5-pro"
GEMINI_20_flash_exp = "gemini-2.0-flash-exp" # exp结尾为实验模型会逐步不再支持
GEMINI_20_FLASH = "gemini-2.0-flash" # 正式版模型
GEMINI_25_FLASH_PRE = "gemini-2.5-flash-preview-05-20" # preview为预览版模型 ,主要是新能力体验
GEMINI_25_FLASH_PRE = "gemini-2.5-flash-preview-05-20"
GEMINI_25_PRO_PRE = "gemini-2.5-pro-preview-05-06"
GEMINI_3_FLASH_PRE = "gemini-3-flash-preview" # Gemini 3 Flash Preview - Agent推荐模型
GEMINI_3_PRO_PRE = "gemini-3-pro-preview" # Gemini 3 Pro Preview
GEMINI_31_PRO_PRE = "gemini-3.1-pro-preview" # Gemini 3.1 Pro Preview - Agent推荐模型
GEMINI_31_FLASH_LITE_PRE = "gemini-3.1-flash-lite-preview" # Gemini 3.1 Flash Lite Preview - Agent推荐模型
# OpenAI
GPT35 = "gpt-3.5-turbo"
GPT35_0125 = "gpt-3.5-turbo-0125"
GPT35_1106 = "gpt-3.5-turbo-1106"
GPT4 = "gpt-4"
GPT4_06_13 = "gpt-4-0613"
GPT4_32k = "gpt-4-32k"
GPT4_32k_06_13 = "gpt-4-32k-0613"
GPT4_TURBO = "gpt-4-turbo"
GPT4_TURBO_PREVIEW = "gpt-4-turbo-preview"
GPT4_TURBO_01_25 = "gpt-4-0125-preview"
GPT4_TURBO_11_06 = "gpt-4-1106-preview"
GPT4_TURBO_04_09 = "gpt-4-turbo-2024-04-09"
GPT4_VISION_PREVIEW = "gpt-4-vision-preview"
GPT_4o = "gpt-4o"
GPT_4O_0806 = "gpt-4o-2024-08-06"
GPT_4o_MINI = "gpt-4o-mini"
GPT_41 = "gpt-4.1"
GPT_41_MINI = "gpt-4.1-mini"
GPT_41_NANO = "gpt-4.1-nano"
GPT_5 = "gpt-5"
GPT_5_MINI = "gpt-5-mini"
GPT_5_NANO = "gpt-5-nano"
GPT_54 = "gpt-5.4" # GPT-5.4 - Agent recommended model
GPT_54_MINI = "gpt-5.4-mini"
GPT_54_NANO = "gpt-5.4-nano"
O1 = "o1-preview"
O1_MINI = "o1-mini"
WHISPER_1 = "whisper-1"
TTS_1 = "tts-1"
TTS_1_HD = "tts-1-hd"
# DeepSeek
DEEPSEEK_CHAT = "deepseek-chat" # DeepSeek-V3对话模型
DEEPSEEK_REASONER = "deepseek-reasoner" # DeepSeek-R1模型
# Qwen (通义千问 - 阿里云)
QWEN = "qwen"
QWEN_TURBO = "qwen-turbo"
QWEN_PLUS = "qwen-plus"
QWEN_MAX = "qwen-max"
QWEN_LONG = "qwen-long"
QWEN3_MAX = "qwen3-max" # Qwen3 Max - Agent推荐模型
QWEN35_PLUS = "qwen3.5-plus" # Qwen3.5 Plus - Omni model (MultiModalConversation)
QWQ_PLUS = "qwq-plus"
# MiniMax
MINIMAX_M2_7 = "MiniMax-M2.7" # MiniMax M2.7 - Latest
MINIMAX_M2_5 = "MiniMax-M2.5" # MiniMax M2.5
MINIMAX_M2_1 = "MiniMax-M2.1" # MiniMax M2.1
MINIMAX_M2_1_LIGHTNING = "MiniMax-M2.1-lightning" # MiniMax M2.1 极速版
MINIMAX_M2 = "MiniMax-M2" # MiniMax M2
MINIMAX_ABAB6_5 = "abab6.5-chat" # MiniMax abab6.5
# GLM (智谱AI)
GLM_5_TURBO = "glm-5-turbo" # 智谱 GLM-5-Turbo - Latest
GLM_5 = "glm-5" # 智谱 GLM-5
GLM_4 = "glm-4"
GLM_4_PLUS = "glm-4-plus"
GLM_4_flash = "glm-4-flash"
@@ -80,20 +110,28 @@ GLM_4_ALLTOOLS = "glm-4-alltools"
GLM_4_0520 = "glm-4-0520"
GLM_4_AIR = "glm-4-air"
GLM_4_AIRX = "glm-4-airx"
GLM_4_7 = "glm-4.7" # 智谱 GLM-4.7 - Agent推荐模型
# Kimi (Moonshot)
MOONSHOT = "moonshot"
KIMI_K2 = "kimi-k2"
KIMI_K2_5 = "kimi-k2.5"
CLAUDE_3_OPUS = "claude-3-opus-latest"
CLAUDE_3_OPUS_0229 = "claude-3-opus-20240229"
CLAUDE_35_SONNET = "claude-3-5-sonnet-latest" # 带 latest 标签的模型名称,会不断更新指向最新发布的模型
CLAUDE_35_SONNET_1022 = "claude-3-5-sonnet-20241022" # 带具体日期的模型名称,会固定为该日期发布的模型
CLAUDE_35_SONNET_0620 = "claude-3-5-sonnet-20240620"
CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
CLAUDE_4_SONNET = "claude-sonnet-4-0"
CLAUDE_4_OPUS = "claude-opus-4-0"
# Doubao (Volcengine Ark)
DOUBAO = "doubao"
DOUBAO_SEED_2_CODE = "doubao-seed-2-0-code-preview-260215"
DOUBAO_SEED_2_PRO = "doubao-seed-2-0-pro-260215"
DOUBAO_SEED_2_LITE = "doubao-seed-2-0-lite-260215"
DOUBAO_SEED_2_MINI = "doubao-seed-2-0-mini-260215"
DEEPSEEK_CHAT = "deepseek-chat" # DeepSeek-V3对话模型
DEEPSEEK_REASONER = "deepseek-reasoner" # DeepSeek-R1模型
# 其他模型
WEN_XIN = "wenxin"
WEN_XIN_4 = "wenxin-4"
XUNFEI = "xunfei"
LINKAI_35 = "linkai-3.5"
LINKAI_4_TURBO = "linkai-4-turbo"
LINKAI_4o = "linkai-4o"
MODELSCOPE = "modelscope"
GITEE_AI_MODEL_LIST = ["Yi-34B-Chat", "InternVL2-8B", "deepseek-coder-33B-instruct", "InternVL2.5-26B", "Qwen2-VL-72B", "Qwen2.5-32B-Instruct", "glm-4-9b-chat", "codegeex4-all-9b", "Qwen2.5-Coder-32B-Instruct", "Qwen2.5-72B-Instruct", "Qwen2.5-7B-Instruct", "Qwen2-72B-Instruct", "Qwen2-7B-Instruct", "code-raccoon-v1", "Qwen2.5-14B-Instruct"]
@@ -104,19 +142,48 @@ MODELSCOPE_MODEL_LIST = ["LLM-Research/c4ai-command-r-plus-08-2024","mistralai/M
"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B","deepseek-ai/DeepSeek-R1-Distill-Qwen-7B","deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B","deepseek-ai/DeepSeek-R1","deepseek-ai/DeepSeek-V3","Qwen/QwQ-32B"]
MODEL_LIST = [
# Claude
CLAUDE3, CLAUDE_4_6_SONNET, CLAUDE_4_6_OPUS, CLAUDE_4_OPUS, CLAUDE_4_5_SONNET, CLAUDE_4_SONNET, CLAUDE_3_OPUS, CLAUDE_3_OPUS_0229,
CLAUDE_35_SONNET, CLAUDE_35_SONNET_1022, CLAUDE_35_SONNET_0620, CLAUDE_3_SONNET, CLAUDE_3_HAIKU,
"claude", "claude-3-haiku", "claude-3-sonnet", "claude-3-opus", "claude-3.5-sonnet",
# Gemini
GEMINI_31_FLASH_LITE_PRE, GEMINI_31_PRO_PRE, GEMINI_3_PRO_PRE, GEMINI_3_FLASH_PRE, GEMINI_25_PRO_PRE, GEMINI_25_FLASH_PRE,
GEMINI_20_FLASH, GEMINI_20_flash_exp, GEMINI_15_PRO, GEMINI_15_flash, GEMINI_PRO, GEMINI,
# OpenAI
GPT35, GPT35_0125, GPT35_1106, "gpt-3.5-turbo-16k",
GPT_41, GPT_41_MINI, GPT_41_NANO, O1, O1_MINI, GPT_4o, GPT_4O_0806, GPT_4o_MINI, GPT4_TURBO, GPT4_TURBO_PREVIEW, GPT4_TURBO_01_25, GPT4_TURBO_11_06, GPT4, GPT4_32k, GPT4_06_13, GPT4_32k_06_13,
GPT4, GPT4_06_13, GPT4_32k, GPT4_32k_06_13,
GPT4_TURBO, GPT4_TURBO_PREVIEW, GPT4_TURBO_01_25, GPT4_TURBO_11_06, GPT4_TURBO_04_09,
GPT_4o, GPT_4O_0806, GPT_4o_MINI,
GPT_41, GPT_41_MINI, GPT_41_NANO,
GPT_5, GPT_5_MINI, GPT_5_NANO,
WEN_XIN, WEN_XIN_4,
XUNFEI,
ZHIPU_AI, GLM_4, GLM_4_PLUS, GLM_4_flash, GLM_4_LONG, GLM_4_ALLTOOLS, GLM_4_0520, GLM_4_AIR, GLM_4_AIRX,
MOONSHOT, MiniMax,
GEMINI_25_PRO_PRE, GEMINI_25_FLASH_PRE, GEMINI_20_FLASH, GEMINI, GEMINI_PRO, GEMINI_15_flash, GEMINI_15_PRO, GEMINI_20_flash_exp,
CLAUDE_4_OPUS, CLAUDE_4_SONNET, CLAUDE_3_OPUS, CLAUDE_3_OPUS_0229, CLAUDE_35_SONNET, CLAUDE_35_SONNET_1022, CLAUDE_35_SONNET_0620, CLAUDE_3_SONNET, CLAUDE_3_HAIKU, "claude", "claude-3-haiku", "claude-3-sonnet", "claude-3-opus", "claude-3.5-sonnet",
"moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k",
QWEN, QWEN_TURBO, QWEN_PLUS, QWEN_MAX,
LINKAI_35, LINKAI_4_TURBO, LINKAI_4o,
GPT_54, GPT_54_MINI, GPT_54_NANO,
O1, O1_MINI,
# DeepSeek
DEEPSEEK_CHAT, DEEPSEEK_REASONER,
# Qwen
QWEN, QWEN_TURBO, QWEN_PLUS, QWEN_MAX, QWEN_LONG, QWEN3_MAX, QWEN35_PLUS,
# MiniMax
MiniMax, MINIMAX_M2_7, MINIMAX_M2_5, MINIMAX_M2_1, MINIMAX_M2_1_LIGHTNING, MINIMAX_M2, MINIMAX_ABAB6_5,
# GLM
ZHIPU_AI, GLM_5_TURBO, GLM_5, GLM_4, GLM_4_PLUS, GLM_4_flash, GLM_4_LONG, GLM_4_ALLTOOLS,
GLM_4_0520, GLM_4_AIR, GLM_4_AIRX, GLM_4_7,
# Kimi
MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k",
KIMI_K2, KIMI_K2_5,
# Doubao
DOUBAO, DOUBAO_SEED_2_CODE, DOUBAO_SEED_2_PRO, DOUBAO_SEED_2_LITE, DOUBAO_SEED_2_MINI,
# 其他模型
WEN_XIN, WEN_XIN_4, XUNFEI,
LINKAI_35, LINKAI_4_TURBO, LINKAI_4o,
MODELSCOPE
]
@@ -124,3 +191,6 @@ MODEL_LIST = MODEL_LIST + GITEE_AI_MODEL_LIST + MODELSCOPE_MODEL_LIST
# channel
FEISHU = "feishu"
DINGTALK = "dingtalk"
WECOM_BOT = "wecom_bot"
QQ = "qq"
WEIXIN = "weixin"

Some files were not shown because too many files have changed in this diff Show More