Compare commits

..

802 Commits

Author SHA1 Message Date
zhayujie
66b71c50e9 feat(wecom_bot): add Wecom Bot QR code scan auth 2026-03-31 21:27:50 +08:00
zhayujie
8744810b25 fix: skill install timeout 2026-03-31 20:47:59 +08:00
zhayujie
7f94d37c2e fix: auto-install font in browser 2026-03-31 20:20:13 +08:00
zhayujie
6d9b7baeb4 fix(weixin): file send failed 2026-03-31 18:14:49 +08:00
zhayujie
4470d4c352 fix: reduce docker image size 2026-03-31 16:56:27 +08:00
zhayujie
d2a462a279 fix: add apt source in docker file 2026-03-31 16:34:47 +08:00
zhayujie
14ff2a15e7 fix(cli): cow cli in docker chat 2026-03-31 16:25:47 +08:00
zhayujie
6d1369900e feat: add source args in docker building 2026-03-31 16:06:45 +08:00
zhayujie
1f17ebe69e feat: add browser install in docker image 2026-03-31 16:05:05 +08:00
zhayujie
1ae2918064 feat: support install browser in chat 2026-03-31 15:15:17 +08:00
zhayujie
b6571e5cad fix: browser resource optimization 2026-03-30 21:39:38 +08:00
zhayujie
7549d48cf1 fix: browser thread bug 2026-03-30 21:27:08 +08:00
zhayujie
00353dd0cb feat: support skill hub mirror 2026-03-30 18:46:02 +08:00
zhayujie
afd947195d fix(cli): support skill mirror install 2026-03-30 16:36:17 +08:00
zhayujie
e57ef37167 fix: prevent phantom mouseover from hijacking slash menu 2026-03-30 11:52:05 +08:00
zhayujie
ef33a93654 Merge pull request #2731 from zkjqd/fix/slash-menu-click
Fix the issue where the shortcut command in the input box cannot be clicked to select events
2026-03-30 11:40:06 +08:00
zhayujie
61732aecfc Merge pull request #2721 from yrk111222/feat/modelscope-update
Feat/modelscope update
2026-03-30 11:39:50 +08:00
zkjqd
6764c05c3f input-slash-click 2026-03-30 11:20:03 +08:00
zhayujie
fa149cf4aa fix(browser): multi-thread browser instance bug 2026-03-30 00:57:19 +08:00
zhayujie
e4f9697d06 feat(browser): install font in linux 2026-03-29 23:52:51 +08:00
zhayujie
da061450e5 fix: github skill install cmd 2026-03-29 19:23:47 +08:00
zhayujie
d09ae49287 feat(browser): auto-snapshot on navigate, screenshot prompt guidance
Browser tool enhancements:
- Navigate action now auto-includes snapshot result, saving one LLM round-trip
- Wait for networkidle + 800ms after navigation for SPA/JS-rendered pages
- Prompt guides agent to screenshot key results and ask user for login/CAPTCHA help
- Fixed playwright version pinned to 1.52.0; mirror fallback to official CDN on failure

Web console file/image support:
- SSE real-time push for images and files via on_event (file_to_send)
- Added /api/file endpoint to serve local files for web preview
- Frontend renders images in media-content container (survives delta/done overwrites)
- File attachment cards with download links; RFC 5987 encoding for non-ASCII filenames

Tool workspace fix:
- Inject workspace_dir as cwd into send and browser tools (previously only file tools)
- Screenshots now save to ~/cow/tmp/ instead of project directory
2026-03-29 19:09:11 +08:00
zhayujie
511ee0bbaf fix: windows PowerShell script 2026-03-29 18:28:50 +08:00
zhayujie
3cb5a0fbd6 docs: add CLI system docs 2026-03-29 17:57:12 +08:00
zhayujie
e06925ab85 fix: optimize browser install cli and fix vision prompt 2026-03-29 15:19:59 +08:00
zhayujie
184634e4e7 fix(cli): browser install failed 2026-03-29 15:14:07 +08:00
zhayujie
843c2d02cc Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-03-29 15:09:37 +08:00
zhayujie
8ea2455766 feat(cli): add browser install cmd 2026-03-29 15:09:07 +08:00
zhayujie
9dc9987d56 Merge pull request #2727 from zhayujie/feat-browser-tool
feat: add browser tool
2026-03-29 14:59:39 +08:00
zhayujie
3458621147 feat: add browser tool 2026-03-29 14:59:06 +08:00
zhayujie
079df5a47c feat: support batch skill install from zip and github 2026-03-29 14:38:11 +08:00
zhayujie
ddb07c65a1 feat: support github zip-first download, gitLab, git@ ssh, local path 2026-03-29 13:45:15 +08:00
zhayujie
9b21cd222b fix: update run.sh 2026-03-28 19:36:51 +08:00
zhayujie
90f736843f fix: add click dependencies 2026-03-28 19:35:15 +08:00
zhayujie
13c020eb61 fix(cli): cli output in wecom_bot 2026-03-28 19:26:59 +08:00
zhayujie
dbc06dbe95 fix: use new run.sh when updating 2026-03-28 19:16:41 +08:00
zhayujie
23d097bc1c Merge pull request #2726 from zhayujie/feat-cow-cli
feat: cow cli in terminal and chat
2026-03-28 19:01:56 +08:00
zhayujie
db85b9808e feat(cli): add cow update 2026-03-28 18:58:42 +08:00
zhayujie
df5bae37bc feat: add MiniMax-M2.7 and glm-5-turbo in web console 2026-03-28 18:48:11 +08:00
zhayujie
acc23b6051 feat: optimize agent prompt and fix skill source load 2026-03-28 18:37:07 +08:00
zhayujie
61f2741afc feat: organize skill source field 2026-03-28 17:41:40 +08:00
zhayujie
4dd7ea886a feat(cli): cli options in web console 2026-03-28 16:26:41 +08:00
zhayujie
1e8959fbcf fix: optimize repo clone in run.sh 2026-03-28 15:08:57 +08:00
zhayujie
48729678cf Merge branch 'master' into feat-cow-cli 2026-03-28 14:47:20 +08:00
zhayujie
0684becaa7 fix(cli): register skill when installing 2026-03-28 14:42:18 +08:00
zhayujie
db16bdf8cb fix(cli): add security hardening for skill install and process management 2026-03-27 17:59:15 +08:00
zhayujie
f890318ed9 fix: strip leading/trailing whitespace from agent response 2026-03-26 18:13:39 +08:00
zhayujie
158510cbbe feat(cli): imporve cow cli and skill hub integration 2026-03-26 16:49:42 +08:00
zhayujie
ce90cf7aa8 fix: weixin cdn upload retry 2026-03-26 10:20:29 +08:00
zhayujie
a3a3d006eb Merge pull request #2723 from Xiaozhou345/Xiaozhou345-fix-readme-spacing
优化 README 中的中英文排版空格
2026-03-26 10:14:27 +08:00
zhayujie
8fd029a4a1 feat(cli): support cow cli 2026-03-26 10:08:51 +08:00
Xiaozhou345
2e1b52c1e5 优化 README 中的中英文排版空格
按照中文技术文档规范,在文件名和中文之间增加了空格,提升可读性。
2026-03-25 21:26:01 +08:00
zhayujie
3eb8348708 fix: docker volume permission issue and clean up unused dependencies 2026-03-25 01:25:34 +08:00
zhayujie
393f0c007c fix: context loss after trim 2026-03-24 20:49:28 +08:00
yrk
294e380288 update model_list 2026-03-24 11:00:55 +08:00
yrk
4c1c42efac feat: update modelscope bot 2026-03-24 10:43:45 +08:00
zhayujie
c062ca8c66 Merge pull request #2720 from 6vision/fix/deepseek-docs
Docs: update
2026-03-24 00:25:17 +08:00
6vision
76dcb25103 docs(deepseek): update model descriptions to V3.2 with thinking/non-thinking mode
Made-with: Cursor
2026-03-24 00:05:39 +08:00
6vision
c5b4f236db docs(deepseek): remove migration notes from zh and en docs
Made-with: Cursor
2026-03-24 00:05:39 +08:00
zhayujie
0974c940a8 Merge pull request #2719 from 6vision/feat/deepseek-bot
feat: add independent DeepSeek bot module with dedicated config
2026-03-23 22:42:58 +08:00
6vision
cffa20d37e docs(deepseek): remove migration notes to reduce user cognitive load
Made-with: Cursor
2026-03-23 22:39:15 +08:00
6vision
ef009edd29 docs(deepseek): update config guides for independent DeepSeek module
Update DeepSeek docs (zh/en/ja) and README to reflect the new dedicated deepseek_api_key / deepseek_api_base config fields, with backward compatibility notes.

Made-with: Cursor
2026-03-23 21:43:51 +08:00
zhayujie
3ca52b118d fix(weixin): qrcode url log 2026-03-23 21:33:53 +08:00
zhayujie
13f5fde4fb fix: rebuild system prompt from scratch on every turn 2026-03-23 21:27:44 +08:00
6vision
f512b55ec2 feat(deepseek): add independent DeepSeek bot module with dedicated config
Separate DeepSeek from ChatGPTBot into its own module (models/deepseek/) with dedicated deepseek_api_key and deepseek_api_base config fields, avoiding config conflicts when switching between providers. Backward compatible with old users who configured DeepSeek via open_ai_api_key/open_ai_api_base through automatic fallback.

Made-with: Cursor
2026-03-23 21:23:35 +08:00
zhayujie
22b8ca0095 feat: optimize vision image compression 2026-03-23 21:18:04 +08:00
zhayujie
baf66a103d fix(weixin): preserve original filename for received files 2026-03-23 01:18:02 +08:00
zhayujie
45faa9c1ff fix(wexin): resolve image/file send and receive failures 2026-03-23 00:13:41 +08:00
zhayujie
304381a88d fix: hide breadcrumb on mobile for better space utilization 2026-03-22 23:36:34 +08:00
zhayujie
fc9f54dbc8 feat(weixin): optimize login qrcode generate 2026-03-22 23:04:50 +08:00
zhayujie
7199dc187f fix: default gemini model 2026-03-22 22:52:37 +08:00
zhayujie
e9ae066d53 Merge pull request #2716 from cowagent/fix-gemini-model-attribute
fix: add missing model property to GoogleGeminiBot
2026-03-22 22:49:00 +08:00
cowagent
d71ae406ff fix: add missing model property to GoogleGeminiBot
api_key and api_base were refactored to @property but model was not
migrated, causing AttributeError: 'GoogleGeminiBot' object has no
attribute 'model' when using any Gemini model.
2026-03-22 22:43:26 +08:00
zhayujie
f3216904b3 feat(weixin): optimize weixin login qrcode 2026-03-22 21:34:47 +08:00
zhayujie
5958b69ec9 feat: release 2.0.4 2026-03-22 20:49:41 +08:00
zhayujie
7d4e2cb39a docs: update comments 2026-03-22 19:07:19 +08:00
zhayujie
a483ec0cea feat: optimize weixin channel qr code generate 2026-03-22 18:20:10 +08:00
zhayujie
c1421e0874 feat: support weixin channel in scripts 2026-03-22 16:29:12 +08:00
zhayujie
ce89869c3c feat: support weixin channel 2026-03-22 15:52:13 +08:00
zhayujie
b8b57e34ff fix: auto-repair messages 2026-03-21 14:20:22 +08:00
zhayujie
bc7f627253 fix(wecom_bot): compat with old websocket-client 2026-03-21 14:03:17 +08:00
zhayujie
652156e398 feat: make run.sh executable 2026-03-20 17:56:10 +08:00
zhayujie
9febb071c6 fix: run.sh get pid bug 2026-03-20 17:51:04 +08:00
zhayujie
7d0e1568ac fix: feishu msg and log encoding 2026-03-19 17:07:39 +08:00
zhayujie
b4e711f411 feat: add request header 2026-03-19 17:06:05 +08:00
zhayujie
1b5be1b981 fix: remove feishu_bot_name in run.sh 2026-03-19 14:55:12 +08:00
zhayujie
49d8707c58 refactor: simplify run.sh by extracting shared logic and eliminating duplication 2026-03-19 11:07:16 +08:00
zhayujie
9192f6f7f7 feat: add MiniMax-M2.7 and glm-5-turbo 2026-03-19 10:46:13 +08:00
zhayujie
05022e3745 fix: add log 2026-03-18 23:09:27 +08:00
zhayujie
5356e9ddeb docs: adjust docs order 2026-03-18 21:55:09 +08:00
zhayujie
52acf76e2c docs: update jp docs 2026-03-18 21:01:02 +08:00
zhayujie
40cdbd3b45 Merge pull request #2710 from eltociear/add-ja-doc
docs: add Japanese documents
2026-03-18 19:28:04 +08:00
Ikko Ashimine
5487c0befe docs: add Japanese documents 2026-03-18 19:13:39 +09:00
zhayujie
8bb16c48c0 docs: update install cmd 2026-03-18 16:11:35 +08:00
zhayujie
c6384363f9 feat: workspace volume in docker deploy 2026-03-18 16:03:03 +08:00
zhayujie
8993e8ad3e feat: release 2.0.3 2026-03-18 15:40:49 +08:00
zhayujie
289989d9f7 feat: release 2.0.3 2026-03-18 15:10:21 +08:00
zhayujie
dc2ae0e6f1 feat: support gpt-5.4-mini and gpt-5.4-nano 2026-03-18 14:55:29 +08:00
zhayujie
9c966c152d feat: enhance AGENT.md update prompts to encourage proactive evolution 2026-03-18 12:10:45 +08:00
zhayujie
4efae41048 feat: support coding plan 2026-03-18 11:59:22 +08:00
zhayujie
b8437032e9 fix: optimize image recognition prompts 2026-03-18 10:10:23 +08:00
zhayujie
2d339ca81b Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-03-17 23:03:05 +08:00
zhayujie
d53abc9696 docs: update README.md 2026-03-17 23:02:41 +08:00
zhayujie
446c886d38 Merge pull request #2706 from zhayujie/feat-web-files
feat: support files upload in web console and office parsing
2026-03-17 21:22:38 +08:00
zhayujie
30c6d9b5ae feat: support file and image upload in web console, add office docs parsing in read tool 2026-03-17 21:21:03 +08:00
zhayujie
5e42996b36 fix: guide LLM to use matching skill when tool not found 2026-03-17 18:34:09 +08:00
zhayujie
ceca7b85bf Merge pull request #2705 from zhayujie/feat-qq-channel
feat: add qq channel
2026-03-17 17:26:39 +08:00
zhayujie
a4d54f58c8 feat: complete the QQ channel and supplement the docs 2026-03-17 17:25:36 +08:00
zhayujie
005a0e1bad feat: add qq channel 2026-03-17 15:43:04 +08:00
zhayujie
46d97fd57d feat: channel config set to env 2026-03-17 11:36:20 +08:00
zhayujie
72a26b6353 fix: scheduler auto clean 2026-03-17 11:29:21 +08:00
zhayujie
89a4033fbf fix: web console bot_type 2026-03-17 10:47:41 +08:00
zhayujie
39a5dc64bd Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-03-16 19:07:54 +08:00
zhayujie
d4bdd9b1b7 docs: update README.md for wecom_bot channel 2026-03-16 19:07:08 +08:00
zhayujie
2f5ba87280 Merge pull request #2698 from zhayujie/feat-wecom-bot
feat: wecom_bot channel
2026-03-16 19:04:52 +08:00
zhayujie
8b45d6c750 docs: wecom_bot integration docs 2026-03-16 19:03:18 +08:00
zhayujie
4ecd4df2d4 feat: web console support wecom_bot config 2026-03-16 17:56:59 +08:00
zhayujie
a42f31fe52 feat: support wecom_bot stream card 2026-03-16 17:46:05 +08:00
zhayujie
d4480b695e feat(channel): add wecom_bot channel 2026-03-16 14:39:15 +08:00
zhayujie
c4b5f7fbae refactor: remove unavailable channels 2026-03-16 11:05:45 +08:00
zhayujie
ba915f2cc0 feat: add gemini-3.1-flash-lite-preview and gpt-5.4 2026-03-15 22:06:12 +08:00
zhayujie
4b91140f31 fix: optimize msg receive 2026-03-12 20:49:36 +08:00
zhayujie
9879878dd0 fix: concurrency issue in session 2026-03-12 17:08:09 +08:00
zhayujie
d78105d57c fix: tool call match 2026-03-12 17:05:27 +08:00
zhayujie
153c9e3565 fix(memory): remove useless prompt 2026-03-12 15:29:58 +08:00
zhayujie
c11623596d fix(memory): prevent context memory loss by improving trim strategy 2026-03-12 15:25:46 +08:00
zhayujie
e791a77f77 fix: strengthen bootstrap flow 2026-03-12 12:13:05 +08:00
zhayujie
b641bffb2c fix(feishu): remove bot_name dependency for group chat 2026-03-12 11:30:42 +08:00
zhayujie
ee0c47ac1e feat: file send prompt 2026-03-12 00:11:34 +08:00
zhayujie
eba90e9343 fix: workspace bootstrap 2026-03-11 23:35:42 +08:00
zhayujie
d8374d0fa5 fix: web_fetch encoding 2026-03-11 19:42:37 +08:00
zhayujie
fa61744c6d feat(web_fetch): support downloading and parsing remote document files (PDF, Word, Excel, PPT) 2026-03-11 17:47:15 +08:00
zhayujie
4fec55cc01 feat: web_featch tool support remote file url 2026-03-11 17:16:39 +08:00
zhayujie
1767413712 fix: increase minimax max_tokens 2026-03-11 15:31:35 +08:00
zhayujie
734c8fa84f fix: optimize skill prompt 2026-03-11 12:40:37 +08:00
zhayujie
9a8d422554 feat: package skill install 2026-03-11 12:18:36 +08:00
zhayujie
b21e945c76 feat: optimize bootstrap flow 2026-03-11 11:27:08 +08:00
zhayujie
a02bf1ea09 Merge pull request #2693 from 6vision/fix/bot-type-and-web-config
fix: rename zhipu bot_type, persist bot_type in web config, fix re.syb escape error
2026-03-11 10:24:19 +08:00
zhayujie
eda82bac92 fix: gemini tool call bug 2026-03-11 02:04:09 +08:00
zhayujie
e8d4f7dc4f fix: remove useless file 2026-03-10 22:56:00 +08:00
6vision
c4a93b7789 fix: rename zhipu bot_type, persist bot_type in web config, fix re.sub escape error
- Rename ZHIPU_AI bot type from glm-4 to zhipu to avoid confusion with model names

- Add bot_type persistence in web config to fix provider dropdown resetting on refresh

- Change OpenAI provider key to chatGPT to match bot_factory routing

- Add DEEPSEEK constant and route it to ChatGPTBot (OpenAI-compatible API)

- Keep backward compatibility for legacy bot_type glm-4 in bot_factory

- Fix re.sub bad escape error on Windows paths by using lambda replacement

- Remove unused pydantic import in minimax_bot.py

Made-with: Cursor
2026-03-10 21:34:24 +08:00
zhayujie
c3f9925097 fix: remove injected max-steps prompt from persisted conversation history 2026-03-10 20:08:59 +08:00
zhayujie
2a0cf7511a Merge pull request #2692 from 6vision/master
update:Adjust bot_type resolution priority in Agent mode
2026-03-10 15:17:22 +08:00
6vision
d0a70d3339 update:Adjust bot_type resolution priority in Agent mode 2026-03-10 15:14:01 +08:00
zhayujie
f37e4675dd Merge pull request #2691 from Weikjssss/fix-bot-type-conf
fix: pass bot_type in agent mode
2026-03-10 15:00:04 +08:00
zhayujie
4e32f67eeb fix: validate tool_call_id pairing #2690 2026-03-10 14:52:07 +08:00
Weikjssss
36d54cab52 fix: pass bot_type in agent mode 2026-03-10 14:28:39 +08:00
zhayujie
9d8df10dcf feat: clarify send tool is local-only 2026-03-10 12:10:10 +08:00
zhayujie
45ea88e070 Merge pull request #2689 from cowagent/fix/openai-compat-complete
fix: complete openai_compat migration across all model bots (openai>=1.0 compatibility)
2026-03-10 10:10:58 +08:00
cowagent
d5d0b947f5 fix: complete openai_compat migration across all model bots
Replace all direct openai.error.* usages with the openai_compat
compatibility layer to support openai>=1.0.

Affected files:
- models/chatgpt/chat_gpt_bot.py: fix isinstance checks (RateLimitError, Timeout, APIError, APIConnectionError)
- models/openai/open_ai_bot.py: replace import + fix isinstance checks
- models/ali/ali_qwen_bot.py: replace import + fix isinstance checks
- models/modelscope/modelscope_bot.py: remove unused openai.error import

The openai_compat layer (models/openai/openai_compat.py) already
handles both openai<1.0 and openai>=1.0 gracefully. This completes
the migration started in the existing PR #2688.
2026-03-10 10:06:04 +08:00
zhayujie
f775f1f11e Merge pull request #2688 from JasonOA888/fix/openai-compat
fix: use openai_compat layer for error handling (openai>=1.0 compatibility)
2026-03-10 10:02:41 +08:00
JasonOA888
f1e888f3de fix: use openai_compat layer for error handling
The code was directly importing openai.error which fails with openai>=1.0.
The project already has an openai_compat.py compatibility layer that handles
both old (<1.0) and new (>=1.0) OpenAI SDK versions.

This commit updates chat_gpt_bot.py to use the compatibility layer.

Related: #2687
2026-03-10 00:33:45 +08:00
zhayujie
71c8436e90 fix: skill download to temp dir 2026-03-09 18:43:28 +08:00
zhayujie
08c69f5e9b fix: clean existing skill directory before remote install to ensure full overwrite 2026-03-09 17:23:09 +08:00
zhayujie
a50fafaca2 refactor: convert image vision from skill to native tool 2026-03-09 16:01:56 +08:00
zhayujie
3c6781d240 refactor: inline skill-creator reference files into SKILL.md 2026-03-09 12:02:52 +08:00
zhayujie
3b8b5625f8 feat: add image vision provider 2026-03-09 11:37:45 +08:00
zhayujie
6be2034110 feat: add fallback embedding provider 2026-03-09 11:03:31 +08:00
zhayujie
924dc79f00 perf: lazy import to avoid 4-10s startup delay 2026-03-09 10:21:58 +08:00
zhayujie
ccb9030d3c refactor: convert web-fetch from skill to native tool 2026-03-09 10:13:48 +08:00
zhayujie
8623287ac1 docs: update memory system docs 2026-03-08 22:06:28 +08:00
zhayujie
022c13f3a4 feat: upgrade memory flush system
- Use LLM to summarize discarded context into concise daily memory entries
- Batch trim to half when exceeding max_turns/max_tokens, reducing flush frequency
- Run summarization asynchronously in background thread, no blocking on replies
- Add daily scheduled flush (23:55) as fallback for low-activity days
- Sync trimmed messages back to agent to keep context state consistent
2026-03-08 21:56:12 +08:00
zhayujie
0687916e7f fix: Safari IME enter key triggering message send
Made-with: Cursor
2026-03-08 13:21:31 +08:00
zhayujie
bb868b83ba feat: add chat history query 2026-03-08 13:03:27 +08:00
zhayujie
24298130b9 fix: minimax tool_id missing 2026-03-06 18:42:03 +08:00
zhayujie
6e5ee92ebd docs: add gpt-5.4 2026-03-06 12:25:50 +08:00
zhayujie
5b91fe04aa fix: send tool process url 2026-03-06 12:22:22 +08:00
zhayujie
1623deb3ee feat: support gpt-5.4 2026-03-06 12:04:40 +08:00
zhayujie
4a16e05b7a fix: rebuild skills when installing 2026-03-05 21:11:34 +08:00
zhayujie
f1c04bc60d feat: improve channel connection stability 2026-03-05 15:55:16 +08:00
zhayujie
84c6f31c76 fix: update agent skill metadata 2026-03-03 18:16:42 +08:00
zhayujie
9d528190bf feat: add skill category 2026-03-03 16:06:37 +08:00
zhayujie
0f23b209ad fix: adjust the context of restart loading 2026-03-03 11:38:14 +08:00
zhayujie
63d9325900 Merge pull request #2683 from pelioo/master
更新.gitignore文件添加python目录忽略规则
2026-03-01 19:41:27 +08:00
peli
f342097f81 Merge remote-tracking branch 'upstream/master' 2026-03-01 00:24:14 +08:00
zhayujie
b4806c4366 fix: model provider config 2026-02-28 18:35:04 +08:00
zhayujie
ff37d8a577 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-02-28 18:10:55 +08:00
zhayujie
a773eb7893 fix: filter history to one user and one assistant per turn 2026-02-28 18:09:02 +08:00
zhayujie
7c67513d24 fix: convert bash-style $VAR to %VAR% on Windows 2026-02-28 18:02:06 +08:00
zhayujie
6ed85029c5 fix: agent skills 2026-02-28 16:46:49 +08:00
zhayujie
e9c57ddf4d fix: adjust default turns 2026-02-28 15:25:20 +08:00
zhayujie
a33ce97ed9 fix: restore only user/assistant text from history, strip tool calls
Made-with: Cursor
2026-02-28 15:14:56 +08:00
zhayujie
b788a3dd4e fix: incomplete historical session messages 2026-02-28 15:03:33 +08:00
zhayujie
fccfa92d7e docs: update channel docs 2026-02-28 14:50:55 +08:00
zhayujie
8705bf0a70 feat: update docs 2026-02-28 10:53:16 +08:00
peli
9318138af7 ```
build(env): 更新.gitignore文件添加python目录忽略规则

在.gitignore文件中新增了python目录的忽略配置,
避免将Python环境相关文件提交到版本控制系统中。
```
2026-02-27 23:49:35 +08:00
zhayujie
269fa7d2d5 feat: 2.0.2 en docs 2026-02-27 18:37:22 +08:00
zhayujie
e99837a8b9 feat: release 2.0.2 2026-02-27 18:04:00 +08:00
zhayujie
553861a2c4 docs: update README.md 2026-02-27 16:57:18 +08:00
zhayujie
628a85d1be docs: update README.md 2026-02-27 16:48:23 +08:00
zhayujie
2cb54514a4 Merge pull request #2681 from zhayujie/feat-docs
feat: docs update
2026-02-27 16:04:17 +08:00
zhayujie
6db22827f2 feat: docs update 2026-02-27 16:03:47 +08:00
zhayujie
4cc6d5426b Merge pull request #2680 from zhayujie/feat-web-config
feat: web console config
2026-02-27 14:40:44 +08:00
zhayujie
7d258b5202 feat(channels): add multi-channel management UI with real-time connect/disconnect
- Web console Channels page: display active channels as config cards, support
  save/connect/disconnect with real-time start/stop of channel processes
- Custom dropdown for channel selection (consistent with model selector style),
  custom confirmation dialog for disconnect
- Fix channel stop: use sys.modules['__main__'] to access live ChannelManager
- Fix web request pending: move stop logic outside lock, set daemon_threads=True
- Fix reconnect: new asyncio event loop per startup, ctypes thread interrupt,
  5s grace period before re-establishing remote connection
- Filter stale offline messages (>60s) pushed after reconnect
2026-02-27 14:39:40 +08:00
zhayujie
c8d19ee0bc Merge pull request #2679 from zhayujie/feat-docs
docs: init docs
2026-02-27 12:14:37 +08:00
zhayujie
d891312032 docs: init docs 2026-02-27 12:10:16 +08:00
zhayujie
5edbf4ce32 feat: model and agent config in web console 2026-02-26 21:01:37 +08:00
zhayujie
3ddbdd713d Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-02-26 18:57:43 +08:00
zhayujie
9ba107b511 Merge branch 'feat-multi-channel' 2026-02-26 18:57:19 +08:00
zhayujie
c9adddb76a fix: pass channel_type correctly in multi-channel mode 2026-02-26 18:57:08 +08:00
zhayujie
f0a12d5ff5 Merge pull request #2678 from zhayujie/feat-multi-channel
feat: support multi-channel
2026-02-26 18:34:48 +08:00
zhayujie
7cce224499 feat: support multi-channel 2026-02-26 18:34:08 +08:00
zhayujie
97397ca585 Merge pull request #2674 from haosenwang1018/fix/bare-excepts
fix: replace 29 bare except clauses with except Exception
2026-02-26 12:11:49 +08:00
zhayujie
f2fbc602a8 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-02-26 10:45:01 +08:00
zhayujie
925d728a86 fix: replace upsert syntax to support SQLite lower version 2026-02-26 10:44:04 +08:00
zhayujie
f5f229871b Merge pull request #2676 from zhayujie/feat-multi-channel
feat: improve web console and conversation store
2026-02-26 10:37:03 +08:00
zhayujie
9917552b4b fix: improve web UI stability and conversation history restore
- Fix dark mode FOUC: apply theme in <head> before first paint, defer
  transition-colors to post-init to avoid animated flash on load
- Fix Safari IME Enter bug: defer compositionend reset via setTimeout(0)
- Fix history scroll: use requestAnimationFrame before scrollChatToBottom
- Limit restore turns to min(6, max_turns//3) on restart
- Fix load_messages cutoff to start at turn boundary, preventing orphaned
  tool_use/tool_result pairs from being sent to the LLM
- Merge all assistant messages within one user turn into a single bubble;
  render tool_calls in history using same CSS as live SSE view
- Handle empty choices list in stream chunks
2026-02-26 10:35:20 +08:00
haosenwang1018
adca89b973 fix: replace bare except clauses with except Exception
Bare `except:` catches BaseException including KeyboardInterrupt and
SystemExit. Replaced 29 instances with `except Exception:`.
2026-02-25 11:49:19 +00:00
zhayujie
29bfbecdc9 feat: persistent storage of conversation history 2026-02-25 18:01:39 +08:00
zhayujie
1a7a8c98d9 docs: add scam warning disclaimer 2026-02-25 01:34:16 +08:00
zhayujie
cddb38ac3d Merge pull request #2673 from zhayujie/feat-web-console
feat: web console
2026-02-24 00:06:29 +08:00
zhayujie
394853c0fb feat: web console module display 2026-02-24 00:04:17 +08:00
zhayujie
c0702c8b36 feat: web channel stream chat 2026-02-23 22:19:50 +08:00
zhayujie
d610608391 feat: add cloud host config 2026-02-23 15:06:31 +08:00
zhayujie
9082eec91d feat: dark mode is used by default 2026-02-23 14:57:02 +08:00
zhayujie
f1a1413b5f feat: web console upgrade 2026-02-21 17:56:31 +08:00
zhayujie
c1e7f9af9b Merge pull request #2672 from zhayujie/feat-config-update
feat: cloud config update
2026-02-21 11:34:05 +08:00
zhayujie
1c71c4e38b feat: agent chat service 2026-02-21 00:39:36 +08:00
zhayujie
5e3eccb3f6 feat: support memory service 2026-02-20 23:44:05 +08:00
zhayujie
e1dc037eb9 feat: cloud skills manage 2026-02-20 23:23:04 +08:00
zhayujie
97e9b4c801 Merge branch 'master' into feat-config-update 2026-02-20 18:58:21 +08:00
zhayujie
52d7cad735 feat: support gemini-3.1-pro-preview and claude-4.6-sonnet 2026-02-20 12:14:59 +08:00
zhayujie
c0b1d270ba Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-02-19 14:18:39 +08:00
zhayujie
e59a2892e4 feat: support qwen3.5-plus 2026-02-19 14:18:16 +08:00
zhayujie
5fa0376a49 Merge pull request #2670 from SgtPepper114/fix/gemini-dingtalk-image-inline
fix(gemini): 修复钉钉图片标记未转多模态导致的识图失效
2026-02-19 13:57:04 +08:00
SgtPepper114
05a33042c8 fix(gemini): support dingtalk image markers as multimodal input
- parse [图片: path] markers in text and convert to Gemini inlineData parts

- unify reply path via call_with_tools to reuse multimodal conversion

- keep legacy safety behavior (BLOCK_NONE) and restore safety ratings logging on empty response

- add multimodal request image-part count log for debugging
2026-02-16 13:26:57 +00:00
zhayujie
ce58f23cbc feat: dashscope model name 2026-02-16 20:11:38 +08:00
zhayujie
b6fc9fa370 fix: run script dependency issues 2026-02-15 00:02:50 +08:00
zhayujie
00ae38faae docs: update models in README 2026-02-14 17:36:36 +08:00
zhayujie
ab28ee58ab feat: add doubao-2.0-code model and update README 2026-02-14 16:49:44 +08:00
zhayujie
48db538a2e feat: support Minimax-M2.5, glm-5, kimi-k2.5 2026-02-14 15:27:44 +08:00
zhayujie
46945942e1 feat: support channel start in sub thread 2026-02-13 12:38:52 +08:00
zhayujie
a24b26a1ef Merge pull request #2667 from cowagent/fix-wechatcom-image-support
fix: 支持企业微信图片消息识别功能
2026-02-12 16:44:18 +08:00
zhayujie
6f8421cdd5 fix: 支持企业微信图片消息识别功能
- 在 ChatGPTBot 中添加 ContextType.IMAGE 处理分支
- 新增 reply_image() 方法,支持 OpenAI Vision API
- 自动 Base64 编码图片并检测格式
- 自动清理临时文件

修复 #2625
2026-02-12 12:00:24 +08:00
zhayujie
284cd9bca9 Merge pull request #2666 from cowagent/fix-model-type-validation
fix: handle non-string model_type to prevent AttributeError
2026-02-10 11:31:45 +08:00
cowagent
23fd6b8d2b fix: handle non-string model_type to prevent AttributeError
When numeric model names (e.g., '1') are used with vLLM and configured
in YAML without quotes, they are parsed as integers. This causes
AttributeError when calling startswith() method.

Changes:
- Add type checking for model_type
- Convert non-string model_type to string with warning log
- Prevents crash when using custom numeric model names

Fixes #2664
2026-02-10 11:07:10 +08:00
zhayujie
4f0ea5d756 feat: make web search a built-in tool 2026-02-09 11:37:11 +08:00
zhayujie
6c218331b1 fix: improve skill system prompts and simplify tool descriptions
- Simplify skill-creator installation flow
- Refine skill selection prompt for better matching
- Add parameter alias and env variable hints for tools
- Skip linkai-agent when unconfigured
- Create skills/ dir in workspace on init
2026-02-08 18:59:59 +08:00
zhayujie
cea7fb7490 fix: add intelligent context cleanup #2663 2026-02-07 20:42:41 +08:00
zhayujie
8acf2dbdfe fix: chat context overflow #2663 2026-02-07 20:36:24 +08:00
zhayujie
0542700f90 fix: issues with empty tool calls and handling excessively long tool results 2026-02-07 20:25:05 +08:00
zhayujie
5264f7ce18 fix: getuid not found in windows 2026-02-07 11:17:58 +08:00
zhayujie
051ffd78a3 fix: windows path and encoding adaptation 2026-02-06 18:37:05 +08:00
zhayujie
bea95d4fae Merge pull request #2661 from cowagent/feat-add-claude-opus-4-6
feat: 添加 Claude Opus 4.6 模型支持
2026-02-06 15:09:49 +08:00
cowagent
fdf7bc312f feat: 添加 Claude Opus 4.6 模型支持
- 在 common/const.py 中添加 CLAUDE_4_6_OPUS 常量
- 将 claude-opus-4-6 添加到 MODEL_LIST
- 在 README.md 中更新 Agent 推荐模型列表
- 在 Claude 配置说明中添加 claude-opus-4-6 支持

Claude Opus 4.6 是 Anthropic 于 2026年2月5日发布的最新模型,
具有更强的规划能力和代码能力,适合作为 Agent 推荐模型。
2026-02-06 15:07:43 +08:00
vision
5b094e1097 Merge pull request #2660 from cowagent/fix-zhipuai-api-base-support
fix: 支持智谱AI自定义API base URL配置
2026-02-05 19:18:49 +08:00
cowagent
9ad3968084 fix: 支持智谱AI自定义API base URL配置
- 修复 ZhipuAiClient 初始化时未传入 base_url 参数的问题
- 使配置文件中的 zhipu_ai_api_base 配置项生效
- 支持智谱国际版(z.ai)等自定义API端点
- 同时修复对话和图片生成功能
- 添加日志输出便于确认使用的API地址

Fixes #2659
2026-02-05 19:06:46 +08:00
zhayujie
3958b6aae1 Merge pull request #2657 from cowagent/fix-missing-runtime-info-parameter
fix: 补充缺失的 runtime_info 参数传递
2026-02-04 22:51:53 +08:00
cowagent
eaa413caf0 fix: 补充缺失的 runtime_info 参数传递
问题:
PR #2655 已合并,但遗漏了关键的参数传递环节。runtime_info 在 agent_initializer.py 中创建并传递给 create_agent(),但 agent_bridge.py 的 create_agent() 方法中没有将其传递给 Agent 实例,导致动态时间更新功能无法生效。

影响:
- Agent 实例的 self.runtime_info 为 None
- get_full_system_prompt() 无法检测到动态时间函数
- 时间戳仍然是静态的,不会实时更新

修复:
在 agent_bridge.py 第 236 行添加:
runtime_info=kwargs.get("runtime_info")

这确保了完整的参数传递链路:
agent_initializer → agent_bridge.create_agent → Agent.__init__

---

*来自 [CowAgent](https://github.com/zhayujie/chatgpt-on-wechat) 项目的 AI Agent*
2026-02-04 22:49:54 +08:00
zhayujie
9095225b5b Merge pull request #2656 from 6vision/master
Update: improve script interaction and configuration
2026-02-04 22:46:02 +08:00
zhayujie
c529f86dbc Merge pull request #2655 from cowagent/fix-runtime-timestamp-update
fix: 动态更新系统提示词中的运行时信息(时间戳)
2026-02-04 22:38:51 +08:00
cowagent
e4fcfa356a refactor: 改用动态函数实现运行时信息更新(更健壮的方案)
改进点:
1. builder.py: _build_runtime_section() 支持 callable 动态时间函数
2. agent_initializer.py: 传入 get_current_time 函数而非静态时间值
3. agent.py: _rebuild_runtime_section() 动态调用时间函数并重建该部分

优势:
- 解耦模板:不依赖具体的提示词格式
- 健壮性:提示词模板改变不会导致功能失效
- 向后兼容:保留对静态时间的支持
- 性能优化:只在需要时才计算时间

相比之前的正则匹配方案,这个方案更加优雅和可维护。
2026-02-04 22:37:19 +08:00
vision
8218cff7c1 Merge branch 'zhayujie:master' into master 2026-02-04 22:32:20 +08:00
6vision
6949bbcf39 update: Improve script interaction and configuration 2026-02-04 22:31:40 +08:00
cowagent
480c60c0a7 fix: 动态更新系统提示词中的运行时信息(时间戳)
问题:
- system_prompt 在 Agent 初始化时固定,导致模型获取的时间信息过时
- 长时间运行的会话中,模型对时间判断不准确

解决方案:
- 在 get_full_system_prompt() 中添加动态更新逻辑
- 每次获取系统提示词时,使用正则表达式替换运行时信息中的时间戳
- 保持其他运行时信息(模型、工作空间等)不变

测试:
- 创建测试脚本验证时间动态更新功能
- 等待3秒后时间正确更新(22:19:45 -> 22:19:48)
2026-02-04 22:27:24 +08:00
zhayujie
eec10cb5db fix: claude remove toolname 2026-02-04 22:15:10 +08:00
zhayujie
02c83d8689 docs: update agent.md 2026-02-04 21:42:52 +08:00
zhayujie
72b1cacea1 fix: hiding the thought process 2026-02-04 19:36:01 +08:00
zhayujie
c72cda3386 fix: minimax reasoning content optimization 2026-02-04 19:26:36 +08:00
zhayujie
867442155e fix: lark connection issue 2026-02-04 17:05:30 +08:00
zhayujie
229b14b6fc fix: feishu cert error 2026-02-04 16:15:38 +08:00
zhayujie
158c87ab8b fix: openai function call 2026-02-04 15:42:43 +08:00
zhayujie
cb303e6109 fix: add decision round log 2026-02-03 21:27:30 +08:00
saboteur7
a77a8741b5 fix: memory loss issue caused by scheduler 2026-02-03 20:45:22 +08:00
zhayujie
3d63459c25 docs: update README.md 2026-02-03 15:44:00 +08:00
saboteur7
ce63de3c58 feat: release 2.0.0 2026-02-03 14:48:30 +08:00
saboteur7
4b3b1219b5 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-02-03 12:20:04 +08:00
saboteur7
73b069a76c docs: update 2.0 README.md 2026-02-03 12:19:36 +08:00
Saboteur7
101cf8d108 Merge pull request #2653 from 6vision/deploy-script
feat: enhance one-click deployment script with full lifecycle management
2026-02-03 03:18:49 +08:00
saboteur7
2e926dfb6e fix: python 3.8 compatibility issues 2026-02-03 03:17:11 +08:00
saboteur7
501866d12a feat: optimize document and model usage 2026-02-03 02:58:15 +08:00
6vision
39bcb0869f feat: enhance one-click deployment script with full lifecycle management 2026-02-03 02:56:46 +08:00
saboteur7
a7b99cde4e Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2026-02-03 01:18:17 +08:00
saboteur7
60abcd92a3 feat: update README.md and solving Python compatibility issues 2026-02-03 01:17:25 +08:00
zhayujie
cdd36e7052 docs: update README.md 2026-02-03 00:48:03 +08:00
saboteur7
c6ac175ce4 docs: update README.md 2026-02-03 00:43:42 +08:00
zhayujie
46bcd87c23 feat: support minimax M2 models 2026-02-02 23:36:23 +08:00
zhayujie
ab74be8e33 feat: add qwen models tool call 2026-02-02 23:08:24 +08:00
zhayujie
d8298b3eab fix: support glm-4.7 2026-02-02 22:43:08 +08:00
zhayujie
50e60e6d05 fix: bug fixes 2026-02-02 22:22:10 +08:00
zhayujie
5d02acbf37 config: add config template 2026-02-02 14:25:34 +08:00
zhayujie
8901d91f96 feat: startup log optimization 2026-02-02 12:25:47 +08:00
zhayujie
b55021bb3d feat: system Initialization log 2026-02-02 12:18:57 +08:00
zhayujie
0ef51b85e6 Merge branch 'feat-cow-agent' 2026-02-02 12:03:55 +08:00
zhayujie
c77566cc02 fix: adjust the maximum step size 2026-02-02 12:03:16 +08:00
zhayujie
c1bcedfb51 Merge pull request #2652 from zhayujie/feat-cow-agent
feat: cow super agent
2026-02-02 11:59:45 +08:00
zhayujie
d085a3c7d7 fix: dingtalk picture and file process 2026-02-02 11:58:19 +08:00
zhayujie
46fa07e4a9 feat: optimize agent configuration and memory 2026-02-02 11:48:53 +08:00
zhayujie
a8d5309c90 feat: add skills and upgrade feishu/dingtalk channel 2026-02-02 00:42:39 +08:00
zhayujie
77c2bfcc1e fix: scheduler in feishu 2026-02-01 19:40:27 +08:00
zhayujie
4c8712d683 feat: key management and scheduled task tools 2026-02-01 19:21:12 +08:00
zhayujie
d337140577 feat: optimize editing tools 2026-02-01 17:46:43 +08:00
zhayujie
99c273a293 fix: write too long file 2026-02-01 17:29:48 +08:00
zhayujie
85578a06b7 fix: memory edit bug 2026-02-01 17:13:32 +08:00
zhayujie
6f70a8efda fix: fts5 not available bug 2026-02-01 17:08:02 +08:00
zhayujie
c693e39196 feat: improve the memory system 2026-02-01 17:04:46 +08:00
zhayujie
4a1fae3cb4 chore: the bot directory was changed to models 2026-02-01 15:21:28 +08:00
zhayujie
08b592816b Merge pull request #2651 from zhayujie/feat-cow-agent
fix: optimize suggestion words and retries
2026-02-01 14:11:53 +08:00
zhayujie
0e85fcfe51 fix: optimize suggestion words and retries 2026-02-01 14:00:28 +08:00
zhayujie
8ef788e799 Merge pull request #2650 from zhayujie/feat-cow-agent
feat: cow agent
2026-02-01 13:14:00 +08:00
zhayujie
645c8899b1 fix: remove tool 2026-02-01 12:38:00 +08:00
zhayujie
9bf5b0fc48 fix: tool call failed problem 2026-02-01 12:31:58 +08:00
zhayujie
07959a3bff fix: first conversation bug 2026-01-31 17:53:12 +08:00
zhayujie
86a6182e41 fix: add logs 2026-01-31 17:29:32 +08:00
zhayujie
89e229ab75 feat: prompt optimization 2026-01-31 17:13:55 +08:00
zhayujie
624917fac4 fix: memory and path bug 2026-01-31 16:53:33 +08:00
zhayujie
489894c61d fix: path prompt 2026-01-31 16:05:20 +08:00
zhayujie
ac87979cb7 fix: bash prompt optimize 2026-01-31 16:01:37 +08:00
zhayujie
5fd3e85a83 feat: add llm retry 2026-01-31 15:53:24 +08:00
zhayujie
0e53ba4311 fix: gemini error process 2026-01-31 14:59:55 +08:00
Saboteur7
3ce57ef851 Merge pull request #2648 from zhayujie/feat-cow-agent
feat: cow agent core
2026-01-31 13:14:05 +08:00
zhayujie
481570d059 fix: invalid syntax 2026-01-31 13:07:51 +08:00
zhayujie
04442b7ddb fix: prompt optimization and gemini fix 2026-01-31 13:02:58 +08:00
zhayujie
e1a71723bc fix: gemini support api base 2026-01-31 12:50:21 +08:00
zhayujie
f044fb8b47 feat: add feishu websocket mode 2026-01-31 12:32:41 +08:00
zhayujie
e3350d5bec feat: optimize prompts and skill creator 2026-01-31 11:20:57 +08:00
saboteur7
8a69d4354e feat: Optimize the first dialogue and memory 2026-01-30 19:10:37 +08:00
saboteur7
dd6a9c26bd feat: support skills creator and gemini models 2026-01-30 18:00:10 +08:00
saboteur7
49fb4034c6 feat: support skills 2026-01-30 14:27:03 +08:00
saboteur7
5a466d0ff6 fix: long-term memory bug 2026-01-30 11:31:13 +08:00
saboteur7
bb850bb6c5 feat: personal ai agent framework 2026-01-30 09:53:46 +08:00
saboteur7
25cf6823d0 fix: remove useless files 2026-01-29 20:00:23 +08:00
vision
7e12744b8b Merge pull request #2634 from 6vision/master
update: delet some banwords
2025-10-22 18:32:10 +08:00
vision
8f2432e0f8 Merge pull request #2632 from 6vision/banwords-delet
Update: delet some bangwords
2025-10-22 17:00:26 +08:00
6vision
94451db638 update: delet some bangwords 2025-10-22 16:58:40 +08:00
zhayujie
f8b8eeec3a Merge pull request #2622 from 6vision/support_gpt-5
feat:Support for the GPT-5 series models
2025-08-08 10:47:49 +08:00
6vision
a4260cc5de feat:Support for the GPT-5 series models 2025-08-08 10:24:15 +08:00
zhayujie
8c1622798b Merge pull request #2612 from 6vision/master
docs: expand channel usage
2025-06-29 22:41:10 +08:00
6vision
e75bed1be5 docs: update README.md 2025-06-29 18:34:49 +08:00
vision
8c0517de0f Merge branch 'zhayujie:master' into master 2025-06-29 17:49:44 +08:00
6vision
94e78365a5 docs: expand channel usage 2025-06-29 17:49:26 +08:00
vision
29c056ca65 Merge pull request #2611 from 6vision/web_channel_update
refactor: improve logger message to use dynamic port
2025-06-29 17:20:00 +08:00
vision
d8c57f27db Merge branch 'zhayujie:master' into master 2025-06-29 17:17:59 +08:00
6vision
3cac2bad55 refactor: improve logger message to use dynamic port 2025-06-29 17:12:28 +08:00
vision
e7905fdf49 docs: expand channel usage
Improve channel integration docs
2025-06-26 19:27:11 +08:00
vision
a492bc2242 docs: expand channel usage 2025-06-26 19:24:39 +08:00
zhayujie
e663364f64 Merge pull request #2609 from 6vision/master
docs: update README.md
2025-06-24 20:45:28 +08:00
6vision
ef6466e26f docs: update README.md 2025-06-24 20:33:52 +08:00
6vision
7fcbbf1cdc docs: update README.md 2025-06-24 17:24:01 +08:00
6vision
ec6ad51ff7 docs: update README.md 2025-06-24 17:20:53 +08:00
zhayujie
1e80c59448 docs: update README.md 2025-06-15 17:44:44 +08:00
zhayujie
e48cb4fd5d chore: remove useless files 2025-06-15 17:33:40 +08:00
zhayujie
7c9fbd2625 docs: improve the readme document 2025-06-15 17:31:41 +08:00
zhayujie
0f504415fb docs: optimize the documentation 2025-06-15 12:42:05 +08:00
zhayujie
4998c324d1 fix: remove chat prefix in web channel 2025-06-07 15:30:22 +08:00
zhayujie
fb5fbe76e8 docs: update docs 2025-05-30 17:06:40 +08:00
zhayujie
223b0bfc88 docs: update README.md 2025-05-30 17:05:04 +08:00
vision
51094a68c8 feat: update Gemini models 2025-05-25 17:44:28 +08:00
6vision
83cb1ec911 feat: update Gemini models 2025-05-25 17:39:17 +08:00
vision
a77e4bfb7a Merge pull request #2596 from 6vision/master
feat: support claude-4-opus and claude-4-sonnet models
2025-05-23 17:19:05 +08:00
6vision
654c177333 docs: update readme.md 2025-05-23 17:12:58 +08:00
vision
b92669ba33 Merge branch 'zhayujie:master' into master 2025-05-23 17:08:23 +08:00
6vision
f2e4f6607d feat:support claude-4-opus and claude-4-sonnet models 2025-05-23 17:07:46 +08:00
zhayujie
5ec909c565 docs: update readme.md 2025-05-23 16:54:58 +08:00
vision
a84f31d54a Merge pull request #2592 from thzjy/fix-1037-baidu-voice
fix: 修复百度语音合成长文处理
2025-05-23 15:14:11 +08:00
vision
e0dd21406d Update baidu_voice.py 2025-05-23 15:13:28 +08:00
vision
72f5f7a0b8 Merge pull request #2565 from dhyarcher/master
Fix access_token expiration handling by processing expires_in and ref…
2025-05-23 14:31:16 +08:00
zhayujie
e3d20085c5 Merge pull request #2595 from zhayujie/feat-agent-plugin
feat: add agent plugin and optimize web channel
2025-05-23 11:59:54 +08:00
zhayujie
8bf1aef801 docs: add web channel and agent plugin docs 2025-05-23 11:56:41 +08:00
Saboteur7
5f7ade20dc feat: web channel support multiple message and picture display 2025-05-23 00:43:54 +08:00
Saboteur7
70d7e52df0 feat: 优化agent插件及webUI对话页面 2025-05-22 17:31:32 +08:00
zhayujie
8e6afa5614 Merge pull request #2593 from zhayujie/feat-web-ui
feat: web ui channel optimization
2025-05-19 11:48:34 +08:00
Saboteur7
a1ae3804e3 feat: web ui channel optimization 2025-05-19 11:41:20 +08:00
thzjy
814ce7a43b fix: 修复百度语音合成长文处理 2025-05-18 17:32:17 +08:00
Saboteur7
628f75009e Merge pull request #2591 from zhayujie/feat-web-ui
feat: new web UI channel
2025-05-18 16:57:57 +08:00
Saboteur7
03fc8c1202 feat: web ui channel update 2025-05-18 16:56:50 +08:00
Saboteur7
8c8e996c87 feat: web channel optimization 2025-05-18 15:23:02 +08:00
vision
933bb0b1fb Merge pull request #2579 from 6vision/web_channel_bug_fix
Fix: fix 'NoneType' object does not support item assignment error (#2525)
2025-04-20 17:22:54 +08:00
6vision
931fbc3eb5 fix: fix 'NoneType' object does not support item assignment error (#2525)
### Problem Description
When `context` is `None`, it should not be used for assignment operations.

### Solution
Adjusted the code logic to ensure that `context` is not `None` before performing any item assignment.
2025-04-20 16:27:44 +08:00
Saboteur7
3db5e70a3d docs: Update README.md 2025-04-15 09:54:24 +08:00
zhayujie
7b19b70d90 Merge pull request #2575 from 6vision/master
feat: support gpt-4.1 series models
2025-04-15 09:25:02 +08:00
6vision
99b8103d70 feat: support gpt-4.1 series models 2025-04-15 09:15:13 +08:00
vision
7167310ccd Merge pull request #2571 from 6vision/master
update readme and adjust some dependency packages.
2025-04-11 16:04:55 +08:00
6vision
263667a2d4 update 2025-04-11 16:03:22 +08:00
6vision
d5cef291f6 update readme and adjust some dependency packages. 2025-04-11 15:50:28 +08:00
vision
c8d166e833 Merge pull request #2544 from wahahage/master
新增腾讯语音
2025-04-11 14:14:55 +08:00
vision
6e25782d8b docs: Delete channel/wechat/README.md 2025-04-11 10:23:05 +08:00
vision
c3127f7e84 Merge pull request #2562 from josephier/support_wcferry
feat: add support for WeChat integration via the wcferry protocol
2025-04-09 18:51:01 +08:00
dhyarcher
7b90fb018b Fix access_token expiration handling by processing expires_in and refreshing the token when expired;修复 access_token 过期处理,添加对 expires_in 的处理并在过期时刷新 token; 2025-04-03 10:13:57 +08:00
josephier
e8bc173cd7 doc: Update and rename readme.md to README.md 2025-03-31 19:39:01 +08:00
josephier
4d1cdf5207 doc:update git url 2025-03-30 16:20:04 +08:00
josephier
57a473364e Merge branch 'zhayujie:master' into master 2025-03-30 15:14:45 +08:00
vision
40b62e9d38 Add support for ModelScope API-Inference
Add support for ModelScope API-Inference
2025-03-30 15:12:29 +08:00
gaojia
ead5f9926b 删除funasr 2025-03-27 10:13:38 +08:00
gaojia
814b6753c2 删除配置文件中的注释 2025-03-26 17:33:39 +08:00
gaojia
ce505251f8 修改配置文件及文件夹名称 2025-03-26 10:01:41 +08:00
yrk
5d2a987aaa Update README.md 2025-03-25 10:38:32 +08:00
yanrk123
4d67e08723 Fix the issue with Chinese description in drawing. 2025-03-18 14:11:22 +08:00
yanrk123
2e71dd5fe2 Fix bug in modelscope_bot.py 2025-03-18 09:47:39 +08:00
yanrk123
c3b9643227 Modify ms_bot.py 2025-03-17 15:46:50 +08:00
josephier
0aad5dc2b7 Update wcferry version
Update wcferry version
2025-03-16 19:16:59 +08:00
yanrk123
cec900168f Modify model list 2025-03-14 13:56:00 +08:00
josephier
f9b1c403d5 docs: Update readme.md 2025-03-12 20:33:35 +08:00
yrk111222
9024b602f5 Update modelscope_bot.py 2025-03-12 16:15:40 +08:00
yanrk123
c139fd9a57 support stream mode for QwQ-32B 2025-03-12 15:45:52 +08:00
yrk111222
e299b68163 Update const.py 2025-03-11 16:48:37 +08:00
yanrk123
7777a53a82 Add supported model list 2025-03-11 16:34:43 +08:00
yanrk123
3e185dbbfe Add support for ModelScope API 2025-03-11 11:12:57 +08:00
josephier
e8a32af369 docs: add README for wx channel based on wcferry
docs: add README for wx channel based on wcferry
2025-03-10 20:36:41 +08:00
josephier
7b0ec6687e docs:add README for WechatFerry channel 2025-03-10 20:29:37 +08:00
gaojia
ec1c6c7b92 新增腾讯语音 2025-03-04 09:56:26 +08:00
josephier
8dfaa86760 chore: remove incomplete features for wchatferry 2025-02-14 00:41:31 +08:00
josephier
323aebd1be feat: add support for WeChat integration via the wchatferry 2025-02-14 00:25:09 +08:00
Saboteur7
436c038a2f fix: temporarily remove unavailable channels 2025-02-05 12:25:30 +08:00
vision
ccd50ec6c0 Merge pull request #2485 from 6vision/master
feat: Add support for deepseek-chat and deepseek-reasoner models
2025-02-04 10:29:24 +08:00
6vision
a7541c2c0f feat: Support #model directive to set model to deepseek-chat and deepseek-reasoner 2025-02-03 21:23:05 +08:00
Saboteur7
c3a57d756c fix: remove channel restrictions 2025-01-31 00:27:20 +08:00
Saboteur7
aa300a4c98 fix: temporarily close the wx channel to prevent account ban 2025-01-17 17:24:42 +08:00
vision
83ea7352b9 Merge pull request #2430 from PJ-568/master
fix: domain type of xunfei lite
2025-01-15 20:03:43 +08:00
Saboteur7
9050712cd8 Update README.md 2024-12-28 16:28:35 +08:00
Saboteur7
8d92fdbb6e Update README.md 2024-12-28 16:27:31 +08:00
zhayujie
a2442ec1b9 Merge pull request #2435 from 6vision/master
fix: resolve display issue for replies containing only image URLs
2024-12-27 00:02:55 +08:00
vision
71662c9cd9 Merge branch 'zhayujie:master' into master 2024-12-26 23:17:21 +08:00
vision
54ff5dbcc2 fix: resolve display issue for replies containing only URLs 2024-12-26 23:16:05 +08:00
zhayujie
4ab7bd3b51 Merge pull request #2431 from 6vision/support-GiteeAI
feat: add gitee-ai models that are compatible with openai format
2024-12-24 20:42:17 +08:00
vision
ef3c61a297 update readme 2024-12-24 19:57:26 +08:00
vision
abf79bf60c add gitee-ai model resources that are compatible with openai format 2024-12-21 17:24:32 +08:00
PJ568
5d3cecd926 fix: domain type of xunfei lite
Reference: [Web API 接口说明](https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E)的 `parameter.chat部分`。
2024-12-20 14:46:25 +08:00
Saboteur7
16324e7283 Merge pull request #2407 from ayasa520/fix_reloadp
fix(plugin): fix reloadp command not taking effect
2024-12-13 15:39:33 +08:00
Saboteur7
9f7e2e1572 Merge pull request #2413 from ayasa520/fix-scanp
fix: Memory leak caused by scanp command due to handler's reference of plugin instance
2024-12-13 14:57:22 +08:00
vision
857ce1d530 Merge pull request #2398 from stonyz/web-channel
增加web channel
2024-12-13 11:45:01 +08:00
vision
be0d72775d Merge pull request #2423 from 6vision/reedme_update_docker_deploy
update readme
2024-12-13 11:41:17 +08:00
vision
7832a2495b Merge pull request #2422 from printlndarling/master
add: add gemini-2.0-flash-exp model
2024-12-13 11:35:26 +08:00
6vision
0506b7f735 update readme 2024-12-13 11:25:36 +08:00
繁星_逐梦
4c0b7942f0 add: gemini-2.0-flash-exp model 2024-12-12 22:22:14 +08:00
繁星_逐梦
651c840c4a add: gemini-2.0-flash-exp model 2024-12-12 22:19:13 +08:00
rikka
2a351ca415 fix(reloadp): clear handlers when reloading plugin to avoid memory leaks 2024-12-05 00:33:00 +08:00
rikka
49b7106d71 fix: Memory leak caused by scanp command due to handler's reference to plugin instance.
close #2412
2024-12-03 22:39:56 +08:00
zhayujie
8bf633f539 Merge pull request #2408 from 6vision/fix-summary-image
图像识别逻辑优化
2024-12-02 21:53:52 +08:00
6vision
0f8efcb4b0 图像识别逻辑优化 2024-12-02 21:16:59 +08:00
Rikka
c567641c5c fix(plugin): fix reloadp command not taking effect
- Use write_plugin_config() instead of directly modifying plugin_config dict
- Add remove_plugin_config() to clear plugin config before reload
- Update plugins to use pconf() and write_plugin_config() for better config management
2024-12-02 16:38:21 +08:00
vision
bdc3820382 Merge pull request #2405 from 6vision/role-plugin-linkai
Linkai bot is compatible with the role plugin.
2024-12-02 12:16:30 +08:00
6vision
33a69a7907 Linkai bot is compatible with the role plugin. 2024-12-02 12:13:26 +08:00
vision
a4d0e9bbc3 Merge pull request #2401 from 6vision/plugins_source_update
插件列表更新
2024-11-29 11:09:27 +08:00
6vision
afc753e1d2 插件列表更新 2024-11-29 11:07:16 +08:00
zhayujie
e641a41224 Update README.md 2024-11-28 21:48:42 +08:00
vision
79305c0632 Merge pull request #2400 from 6vision/readme_update
readme update
2024-11-28 12:59:00 +08:00
6vision
ef2ce3f09d 说明文档更新 2024-11-28 12:41:00 +08:00
Stony
71c18c04fc 增加web channel 2024-11-27 08:53:13 +08:00
Saboteur7
cf84e57f81 fix: add exception handling 2024-11-15 11:58:10 +08:00
vision
9421d44579 Merge pull request #2373 from 6vision/summary_app_code
Buy using app code, supports custom summary prompt .
2024-11-07 20:16:53 +08:00
6vision
5cd2ae8cc8 Summary supports app_code 2024-11-06 21:45:03 +08:00
vision
22d67b3a59 Merge pull request #2364 from 6vision/1031
1.7.3 release readme
2024-10-31 14:44:55 +08:00
6vision
e102cbb8c4 1.7.3 release readme 2024-10-31 14:39:11 +08:00
vision
d90eeb7ee4 Merge pull request #2363 from 6vision/linkai_plugin
Summary and MJ  support can be configured through LinkAI platform app plugins
2024-10-31 11:50:53 +08:00
vision
1989d53031 Merge pull request #2361 from 6vision/claude_model_update
Claude model update
2024-10-31 11:50:11 +08:00
6vision
04ef0907b4 Summary and MJ support can be configured through LinkAI platform app plugins. 2024-10-31 11:15:44 +08:00
6vision
517b43561c Merge branch 'claude_model_update' of git@github.com:6vision/chatgpt-on-wechat.git into claude_model_update 2024-10-28 00:32:46 +08:00
6vision
ccb8c7227f Support setting base URL and proxy for Claude model. Also support reset command. 2024-10-28 00:32:05 +08:00
vision
9fbfeeb04f Merge branch 'zhayujie:master' into claude_model_update 2024-10-27 23:43:16 +08:00
6vision
8b753a5a1f Signed-off-by: 6vision <vision_wangpc@sina.com> 2024-10-27 21:44:06 +08:00
6vision
d25cab0627 Claude model supports system prompts. 2024-10-27 21:37:58 +08:00
6vision
84da0a8a35 feat:update claude-35-sonnet model 2024-10-24 20:57:03 +08:00
vision
6f665cffba Merge pull request #2354 from 6vision/group_patpat_note
fix: group patpat notes
2024-10-24 19:53:18 +08:00
6vision
aea8ac2e97 Signed-off-by: 6vision <vision_wangpc@sina.com> 2024-10-24 19:48:50 +08:00
vision
8418fa7b45 Merge pull request #2344 from 6vision/markdown_format_display
Optimize markdown format display
2024-10-21 10:27:03 +08:00
6vision
9cc4d0ee07 Optimize markdown format display 2024-10-21 10:23:39 +08:00
Saboteur7
da60831c44 fix: fixed the version of qrcode dependency 2024-10-19 16:14:49 +08:00
Saboteur7
0773174a20 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2024-10-19 15:55:04 +08:00
Saboteur7
70e007d8ca fix: try to solve the unresponsiveness problem 2024-10-19 15:49:57 +08:00
vision
fcc4d02c2f Merge pull request #2339 from 6vision/master
Optimize Gemini model character statistics
2024-10-14 12:19:27 +08:00
vision
f4a5f00593 Merge branch 'zhayujie:master' into master 2024-10-14 12:18:33 +08:00
6vision
1170ed6566 Optimize Gemini model character statistics 2024-10-14 12:17:10 +08:00
zhayujie
883f0d449b Merge pull request #2317 from 6vision/master
feat: add install.sh and run.sh
2024-09-26 16:43:56 +08:00
6vision
f4c62e7844 update install.sh url 2024-09-26 16:43:12 +08:00
6vision
f0d212a9d2 Merge branch 'master' of github.com:6vision/chatgpt-on-wechat 2024-09-26 16:02:19 +08:00
6vision
76a8974034 update run.sh 2024-09-26 16:01:44 +08:00
vision
0614e822f4 Merge branch 'zhayujie:master' into master 2024-09-26 13:07:45 +08:00
vision
6f682c9a2e Merge pull request #2311 from cmgzn/master
fix: gemini doesn't receive system messages...
2024-09-26 13:04:47 +08:00
6vision
a9fdbc31c5 update date 2024-09-26 13:02:38 +08:00
cmgzn
086fdb5856 fix gemini logger 2024-09-26 02:49:52 +01:00
6vision
63c8ef4f17 feat: install.sh and run.sh 2024-09-26 00:34:52 +08:00
zhayujie
736f6523c7 Merge branch 'master' into master 2024-09-25 23:11:13 +08:00
vision
8b0b360d25 Merge pull request #2288 from KuroIVeko/patch-3
Support more models from Zhipu AI
2024-09-25 22:28:16 +08:00
vision
80b84e2ee6 Merge pull request #2277 from KuroIVeko/patch-1
Lower Gemini's safety thresholds
2024-09-25 22:24:20 +08:00
vision
b5b7d86f7b Merge pull request #2278 from 6vision/moonshoot
fix: "model":"mooshoot", which defaults to "moonshot-v1-32k".
2024-09-25 22:10:40 +08:00
cmgzn
f20d704390 fix: gemini doesn't receive system messages; change session to gpt method, add system messages as user messages to the gemini, and logging historical messages 2024-09-20 09:10:21 +01:00
vision
e4e1e2e944 Merge pull request #2306 from 6vision/master
fix: Linkai voice configuration
2024-09-18 19:43:41 +08:00
vision
6bc7eeb4cc Merge branch 'zhayujie:master' into master 2024-09-18 19:41:23 +08:00
6vision
656ed5de7b fix: LinkAI voice onfiguration 2024-09-18 19:40:51 +08:00
zhayujie
a11d695c78 Merge pull request #2300 from 6vision/master
feat: support o1-preview and o1-mini model
2024-09-13 10:50:04 +08:00
6vision
c4f9acd5c5 update 2024-09-13 10:48:51 +08:00
6vision
5ef929dc42 o1 model support #model 2024-09-13 10:21:38 +08:00
6vision
c8cf27b544 feat: support o1-preview and o1-mini model 2024-09-13 10:13:23 +08:00
vision
bb5ecfc398 Merge pull request #2298 from 6vision/error_print_ascii_windows
Handle ASCII QR code print error on Windows
2024-09-11 22:35:30 +08:00
6vision
c91e7c35bb Remove unused imports 2024-09-11 22:34:33 +08:00
6vision
532d56df2d Handle ASCII QR code print error on Windows 2024-09-11 22:30:25 +08:00
KurolVeko
111ad44029 Update const.py 2024-09-05 11:07:06 +08:00
KurolVeko
6b02bae957 Update bridge.py 2024-09-05 10:59:57 +08:00
vision
6831743416 Merge pull request #2286 from 6vision/gpt
feat: support gpt-4o-2024-08-06 model
2024-09-04 18:44:08 +08:00
6vision
63e2f42636 feat: support gpt-4o-2024-08-06 model 2024-09-04 18:39:29 +08:00
6vision
f6e6805453 fix: "model":"mooshoot", which defaults to "moonshot-v1-32k". 2024-08-31 16:09:10 +08:00
KurolVeko
ad77ad8f2b Lower Gemini's safety thresholds
Gemini's default safety thresholds are set too high, resulting in frequent censorship of generated text. I have lowered the thresholds for all four safety categories according to Google's documentation.
2024-08-30 17:00:51 +08:00
Saboteur7
469524e8ae Merge pull request #2206 from VanJohnPK/master
fix azure voice error 修复Azure语音服务报错问题
2024-08-29 11:33:49 +08:00
Saboteur7
f4f55d5dfd Merge pull request #2247 from byang822/abacusoft-alex
wenxin character model supports prompt
2024-08-29 11:31:45 +08:00
Saboteur7
c248d0f3f4 Merge pull request #2262 from 6vision/cancel_wecom_subscribe
Cancel subscribe_msg of wechatcomapp channel
2024-08-29 11:31:04 +08:00
Saboteur7
648a04b513 Merge pull request #2265 from 6vision/feat0825
Support configuration whether to be @ in group chat.
2024-08-29 11:30:46 +08:00
vision
bdc86c16ec Merge pull request #2268 from 6vision/xunfei_system_prompt
Xunfei supports system prompt(character_desc).
2024-08-27 20:46:07 +08:00
6vision
21efd17c17 Xunfei supports system prompt(character_desc). 2024-08-25 22:22:29 +08:00
Saboteur7
aaa75e7b62 Merge pull request #2267 from 6vision/master
Optimize the welcome message for new members.
2024-08-25 17:16:11 +08:00
6vision
6d0cef3152 Optimize the welcome message for new members. 2024-08-25 17:10:44 +08:00
Saboteur7
c18472289f Merge pull request #2207 from Abyss-Seeker/master
支持更多语言(英语)的微信客户端
2024-08-25 16:10:33 +08:00
6vision
02b7c70a81 Support configuration whether to be @ in group chat. 2024-08-25 15:13:25 +08:00
6vision
4eaa2b93c6 Cancel subscribe_msg of wechatcomapp channel 2024-08-22 22:03:04 +08:00
darkVinci
d347905373 Merge pull request #1 from zhayujie/master
merge 15 commits
2024-08-21 11:21:31 +08:00
vision
f495213b2c Merge pull request #2237 from 6vision/fix_role
Optimize log information printing
2024-08-17 17:01:08 +08:00
Alex Yang
9b125913ae wenxin character model supports prompt 2024-08-16 14:58:17 +08:00
6vision
da81f05804 Optimize log information printing 2024-08-14 23:03:57 +08:00
Abyss-Seeker
9a371a4d4d Update wechat_message.py
加入更多英文适配(通过QR code加入群聊)
2024-08-06 23:30:32 +08:00
Abyss-Seeker
1e92828f1a 支持更多语言(英语)
加入了notes_join_group,notes_exit_group,notes_patpat列表,可以在加入群聊,退出群聊和拍一拍消息中匹配更多的字符。在此完成了英语(invited, removed, tickled)的匹配,使如果微信语言是英文的话也可以正常识别啦!同时,以后也可以通过加list和判断语句的方式支持更多语言!
2024-08-04 10:14:23 +08:00
Saboteur7
7e724b3fa3 Update README.md 2024-08-02 16:06:25 +08:00
vision
3f5b976a87 Merge pull request #2181 from 6vision/webp_images
Support images in webp format.
2024-08-02 13:47:39 +08:00
vision
49f2339cc2 Merge pull request #2203 from 6vision/fix_issues
Fix issues
2024-08-02 13:30:14 +08:00
vision
29f1699de8 Merge pull request #2198 from 6vision/update_spark
Support Spark4.0 Ultra model, optimize model configuration.
2024-08-02 01:38:15 +08:00
6vision
c415485801 Support Spark4.0 Ultra model, optimize model configuration. 2024-08-01 17:57:48 +08:00
zhayujie
6937673472 Merge pull request #2193 from 6vision/fix_tool
Default close tool plugin.
2024-07-31 14:09:33 +08:00
6vision
c4f10fe876 fix: Default close tool plugin. 2024-07-31 00:01:56 +08:00
6vision
55ca652ad8 Default close tool plugin. 2024-07-30 23:14:23 +08:00
Zheng
3effd5afd1 fix azure voice error 2024-07-30 17:10:02 +08:00
Saboteur7
000c2029de fix: remove some tools 2024-07-30 12:35:12 +08:00
Saboteur7
ab88e3af06 fix: remove some default tools 2024-07-30 12:15:35 +08:00
6vision
b544a4c954 fix: Use default expiration time for ExpiredDict if not set in config 2024-07-29 20:14:41 +08:00
6vision
baff5fafec Optimization 2024-07-28 00:03:16 +08:00
6vision
1673de73ba Role plugin supports more bots. 2024-07-25 22:58:57 +08:00
6vision
e68936e36e Support images in webp format. 2024-07-25 01:19:44 +08:00
6vision
7dbd195e45 Support images in webp format. 2024-07-25 01:12:53 +08:00
vision
3dc22f98bf Merge pull request #2177 from 6vision/Opti-azure-dalle
Optimize error messages when using Azure Dalle
2024-07-24 12:38:13 +08:00
6vision
805e870c18 Optimize error messages when using Azure Dalle 2024-07-24 00:06:18 +08:00
Saboteur7
de2c031797 docs: update readme 2024-07-19 15:46:19 +08:00
Saboteur7
3aa571aa1b Merge pull request #2163 from 6vision/wechatcom_app
Ensure compatibility for /wxcomapp URL with trailing slash
2024-07-19 15:38:20 +08:00
Saboteur7
3e4969efe6 Merge branch 'master' into wechatcom_app 2024-07-19 15:38:08 +08:00
Saboteur7
446e94df76 Merge pull request #2164 from 6vision/mini_bot
Support gpt-4o-mini model
2024-07-19 15:37:30 +08:00
Saboteur7
5b26066a4c Merge pull request #2154 from distiny-cool/ali_api
增加了使用阿里云进行语音识别的引擎
2024-07-19 15:37:05 +08:00
Saboteur7
8a80de5c3f Merge pull request #2141 from Yanyutin753/new
PictureChange插件功能升级
2024-07-19 15:36:02 +08:00
6vision
52a490c87e Support gpt-4o-mini model 2024-07-19 11:04:45 +08:00
6vision
29490741fd Ensure compatibility for /wxcomapp URL with trailing slash 2024-07-18 23:21:45 +08:00
kody
f0e416455f 增加了使用阿里云进行语音识别的引擎 2024-07-15 22:03:31 +08:00
vision
f7a2c97943 Merge pull request #2153 from 6vision/update_linkaibot
support more file types.
2024-07-15 19:09:05 +08:00
6vision
993853757b Linkai bot supports more file types. 2024-07-15 18:57:58 +08:00
6vision
a3abfb987d update 2024-07-15 18:50:38 +08:00
Saboteur7
2711fa1b1b Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2024-07-08 19:00:03 +08:00
Saboteur7
1f7afaba07 fix: client cmd config bug 2024-07-08 18:57:27 +08:00
Clivia
e02c8bff81 PictureChange插件功能升级 2024-07-08 17:58:59 +08:00
Saboteur7
22391ba1a5 Update README.md 2024-07-05 15:45:54 +08:00
Saboteur7
a05781ec19 Merge pull request #2103 from 6vision/claude-3.5-sonnet
feat: support claude-3.5-sonnet model
2024-07-05 14:39:17 +08:00
Saboteur7
f898ed6a2a Merge branch 'master' into claude-3.5-sonnet 2024-07-05 14:32:45 +08:00
Saboteur7
e6d0a15b54 Merge pull request #2110 from He0607/新增高铁(火车)票查询插件
新增高铁(火车)票查询插件
2024-07-05 14:31:15 +08:00
Saboteur7
49cff026e2 Merge pull request #2113 from 6vision/update-0626
Update parameter descriptions for clarity
2024-07-05 14:26:33 +08:00
Saboteur7
08f0023cfd Merge pull request #2124 from 6vision/update_gemini_model
Update gemini 1.5model
2024-07-05 14:26:13 +08:00
Saboteur7
e311466ee6 Merge pull request #2128 from Maroon9/fix-docker-compose
fix:在docker-compose.yml文件中增加时区设置
2024-07-05 14:25:56 +08:00
wanxiangze
56789e68d7 fix:在docker-compose.yml文件中增加时区设置 2024-07-05 10:18:21 +08:00
6vision
87525bb383 update gemini model 2024-07-04 01:44:53 +08:00
6vision
bb2880191a update gemini model 2024-07-04 01:22:55 +08:00
6vision
4f1acf26d6 Merge branch 'update-0626' of https://github.com/6vision/chatgpt-on-wechat into update-0626 2024-06-27 21:11:14 +08:00
6vision
fc2d6b21ac update 2024-06-27 21:09:54 +08:00
zhayujie
b9e84fefbd Merge pull request #2114 from 6vision/fix_dingtalk_group_chat
fix: dingtalk channel group chat bug
2024-06-27 10:29:51 +08:00
6vision
91f5ffb2d9 Correct the log information 2024-06-26 22:34:35 +08:00
6vision
70ff2341cb fix:dingtalk channel group chat bug 2024-06-26 22:10:58 +08:00
vision
74eed93497 Merge branch 'zhayujie:master' into update-0626 2024-06-26 15:15:32 +08:00
6vision
d02e26c014 Update parameter descriptions for clarity 2024-06-26 15:14:29 +08:00
Wu_Cool
523cade7c3 新增高铁(火车)票查询插件 2024-06-26 09:13:40 +08:00
Wu_Cool
e22c183ca9 新增高铁(火车)票查询插件 2024-06-26 09:11:04 +08:00
vision
3afd99da30 Merge pull request #2106 from 6vision/fix_sensitive
Fix TypeError in config drag_sensitive function
2024-06-24 22:04:56 +08:00
6vision
f44979f983 Fix TypeError in config drag_sensitive function 2024-06-24 21:57:58 +08:00
6vision
095f9cc108 feat: support claude-3.5-sonnet model 2024-06-24 11:20:50 +08:00
zhayujie
1089076fce Merge pull request #2044 from Wang-zhechao/add-plugins-solitaire
添加微信接龙插件
2024-06-20 20:41:37 +08:00
Saboteur7
cad3b691a9 Update README.md 2024-06-20 16:09:19 +08:00
Saboteur7
bac21426d3 fix: minimax model list 2024-06-20 15:26:16 +08:00
Saboteur7
c4a35314cd Merge pull request #2071 from lmy668/master
feat#add minmax model
2024-06-20 15:21:41 +08:00
Saboteur7
7090722565 Merge branch 'master' into master 2024-06-20 15:21:20 +08:00
Saboteur7
6d972c7c18 Merge pull request #2046 from 6vision/update_mode_list
Update mode list
2024-06-20 15:09:05 +08:00
Saboteur7
6961a88feb Merge pull request #2060 from k8scat/remove-unused-import
remove unused import
2024-06-20 15:06:44 +08:00
6vision
c41ec13984 fix terminal channel 2024-06-15 16:34:32 +08:00
6vision
ca8e06e562 兼容符合openai请求格式的三方服务,根目录的config.json里增加配置"bot_type": "chatGPT" 2024-06-13 16:43:03 +08:00
limy26
200cd33a8e feat#add minmax model 2024-06-12 19:30:24 +08:00
6vision
1da7991c65 fix 2024-06-08 00:09:05 +08:00
K8sCat
fdfb7e369a remove unused import
Signed-off-by: K8sCat <k8scat@gmail.com>
2024-06-07 14:48:54 +08:00
6vision
c2b01cc957 Add configuration to plugin configuration template. 2024-06-05 17:10:08 +08:00
6vision
5de8e94bb4 update readme 2024-06-05 01:25:03 +08:00
6vision
7a2c15d912 Update model list 2024-06-05 00:44:08 +08:00
Wang Zhechao
70344dd214 添加微信接龙插件 2024-06-04 22:39:59 +08:00
zhayujie
405372d1a7 Merge pull request #1753 from MasterKeee/master
新增公众号的回复视频类型
2024-06-04 14:25:11 +08:00
Saboteur7
b8c5174da5 docs: xunfei voice comment 2024-06-04 13:49:44 +08:00
Saboteur7
1f6f9103d9 docs: update README.md 2024-06-04 12:50:59 +08:00
Saboteur7
6431487c7a fix: drag sensitive bug 2024-06-04 12:02:23 +08:00
Saboteur7
8b2d1189db Merge pull request #1999 from njnuko/voice-xunfei
add xunfei voice
2024-06-04 11:43:55 +08:00
Saboteur7
b777f27cb7 chore: remove some xunfei voice log 2024-06-04 11:42:05 +08:00
Saboteur7
b31c3b124a Merge pull request #1972 from Undertone0809/zeeland/add-logger-drag-sensitive
feat: add logger drag sensitive
2024-06-04 11:26:05 +08:00
Saboteur7
fa1e965fba feat: add dingtalk card switch 2024-06-04 11:23:45 +08:00
Saboteur7
91dc8b4d58 Merge pull request #1994 from baojingyu/feat-05-17
钉钉接入增加流式输出支持,语音、图片或富文本消息接收
2024-06-04 10:53:02 +08:00
Saboteur7
6d16ea8830 Update requirements.txt 2024-06-04 10:49:17 +08:00
Saboteur7
7db4253264 Update chat_channel.py 2024-06-04 10:47:56 +08:00
Saboteur7
4d2b7d9bf9 Update chat_channel.py 2024-06-04 10:47:05 +08:00
Saboteur7
8f6f4acb88 Update chat_channel.py 2024-06-04 10:43:19 +08:00
Saboteur7
f20d84cb37 Merge pull request #1809 from whw23/master
Azure OpenAI Dalle fix
2024-06-03 22:46:07 +08:00
Saboteur7
afbdf1d5d5 Merge pull request #2002 from 6vision/time_check
fix: time_check model
2024-06-03 22:40:01 +08:00
Haowei
bc8364d594 Merge branch 'zhayujie:master' into master 2024-05-25 23:34:47 +08:00
vision
c8d388f70f Merge pull request #2013 from 6vision/fix_baidu_voice
Changed sampling rate
2024-05-23 01:36:00 +08:00
6vision
be13cc3194 Changed sampling rate 2024-05-23 01:34:20 +08:00
vision
a46320e744 Merge pull request #2012 from 6vision/fix_issue_1959_
Fix issue 1959 wenxin模型返回报错
2024-05-22 21:45:20 +08:00
6vision
071709d263 fix: 1959-百度文心偶发报错336006 2024-05-22 16:01:46 +08:00
6vision
93a32ae5ff 修复模型请求异常时的bug 2024-05-22 15:57:22 +08:00
vision
eee96f226f Merge pull request #2005 from 6vision/fix_baidu_voice
fix: baidu voice bug
2024-05-21 22:38:54 +08:00
6vision
e19a8b479c fix: baidu voice bug 2024-05-21 22:32:35 +08:00
6vision
9ef459112e fix: time_check model 2024-05-20 20:37:00 +08:00
Haowei
e96474bd5c Merge branch 'zhayujie:master' into master 2024-05-20 16:53:02 +08:00
njnuko
6fed719e09 add Xunfei Voice
Signed-off-by: njnuko <njnuko@163.com>
2024-05-20 15:04:23 +08:00
zhayujie
99aac76618 docs: update readme 2024-05-18 19:03:17 +08:00
baojingyu
599f458201 Update plugins source.js add midjourney实现ai绘图的的插件 2024-05-17 15:38:19 +08:00
baojingyu
2f8099059c 修复chat_channel配置参数取值错误bug,优化dingtalk_channel回复打字机效果流式 AI卡片、dingtalk_message图片或富文本消息接收 2024-05-17 14:48:52 +08:00
zhayujie
e24f177832 Merge pull request #1993 from 6vision/fix_linkai_pconf
fix: linkai plugin config_template
2024-05-17 01:25:30 +08:00
6vision
48cc143e88 fix: linkai plugin config_template 2024-05-17 01:22:38 +08:00
zhayujie
b09b46c045 fix: summary switch bug 2024-05-14 17:48:18 +08:00
zhayujie
2c6583cc9c fix: summary switch bug 2024-05-14 17:26:10 +08:00
zhayujie
e381d1bfb8 feat: support gpt-4o model 2024-05-14 09:50:03 +08:00
zeeland
eac619d54f feat: add logger drag sensitive 2024-05-13 19:53:33 +08:00
zhayujie
a6ef3bc0ce fix: add channel login exception log 2024-05-08 12:54:13 +08:00
zhayujie
118122c541 docs: update README.md 2024-05-08 12:07:59 +08:00
zhayujie
bfdf33ac09 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2024-05-07 11:37:53 +08:00
zhayujie
fa3370df5b fix: image model check 2024-05-07 11:37:27 +08:00
zhayujie
f1e51672c5 Merge pull request #1944 from alvinsuDL/patch-1
Update README.md
2024-05-07 11:20:43 +08:00
alvinsuDL
91f97b2728 Update README.md 2024-05-07 11:16:41 +08:00
zhayujie
2c542e03fe Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2024-05-07 11:10:41 +08:00
zhayujie
71a11b4267 feat: support mj client config 2024-05-07 11:09:49 +08:00
zhayujie
ea642757db docs: update README.md 2024-05-06 22:19:49 +08:00
zhayujie
fb72b601aa fix: model config 2024-05-03 19:41:12 +08:00
zhayujie
27e507e744 fix: update client sdk version 2024-05-03 19:10:27 +08:00
zhayujie
4db19f816f feat: update service url 2024-05-03 14:10:07 +08:00
zhayujie
096d5776d1 feat: v1.6.0 verson update 2024-04-26 16:13:53 +08:00
zhayujie
3d799eb4d9 Merge pull request #1893 from uxfion/fix-openai-whisper
fix openai voice_to_text whisper
2024-04-26 15:37:34 +08:00
zhayujie
e4ac3afa4d Merge pull request #1849 from wayshall/kimi
feat: 增加moonshot api集成
2024-04-26 15:17:52 +08:00
zhayujie
d38e4eed5b Merge pull request #1904 from fatwang2/master
新增url解析逻辑,解决itchat中分享卡片无法解析的问题
2024-04-20 11:09:51 +08:00
fatwang2
97787fac91 新增url解析逻辑,解决itchat中分享卡片无法解析的问题 2024-04-20 00:48:33 +08:00
Lecter
b494ee2f1c fix openai voice_to_text whisper 2024-04-14 14:33:17 +08:00
zhayujie
31ac80a074 Merge pull request #1851 from wayshall/qwen-dashscope
feat: 通义千问使用新版的sdk实现
2024-04-09 16:06:33 +08:00
zhayujie
c8896450f6 fix: add warn log in glm 2024-04-09 15:57:59 +08:00
zhayujie
c662fa4c63 Merge pull request #1871 from cgnannan/master
修复 Issues #1868提到的elevenlabs sdk更新问题
2024-04-09 15:52:35 +08:00
zhayujie
db2ee802ca chore: log optimization 2024-04-09 15:35:18 +08:00
Haowei
d40e915e2b Merge branch 'zhayujie:master' into master 2024-04-09 11:31:57 +08:00
zhayujie
c0616e7efa Merge pull request #1881 from 6vision/feat_local
优化Hello插件。支持自定义欢迎语提示词以及为不同群设置不同的固定欢迎语
2024-04-09 10:46:22 +08:00
6vision
01660597e3 Merge branch 'feat_local' of git@github.com:6vision/chatgpt-on-wechat.git into feat_local 2024-04-08 23:09:08 +08:00
6vision
c5b549f450 优化hello插件 2024-04-08 23:06:35 +08:00
vision
802d8457bb Merge branch 'zhayujie:master' into feat_local 2024-04-08 23:05:39 +08:00
zhayujie
c3a3df67b0 Merge pull request #1847 from Yanyutin753/master
fix ReplyType.IMAGE 回复图片为空的BUG
2024-04-08 12:15:49 +08:00
6vision
5798aeb3cd Merge branch 'update-hello' of git@github.com:6vision/chatgpt-on-wechat.git into feat_local 2024-04-07 22:34:52 +08:00
6vision
cc81dd9172 Signed-off-by: 6vision <vision_wangpc@sina.com> 2024-04-07 22:31:08 +08:00
Haowei
44fdadda08 Merge branch 'zhayujie:master' into master 2024-04-07 14:54:48 +08:00
zhayujie
66a014150b fix: config update bug 2024-04-06 01:03:26 +08:00
zhayujie
1da596639f feat: update sdk version 2024-04-06 00:19:22 +08:00
zhayujie
76614ae9e5 fix: remote config load bug 2024-04-05 23:47:02 +08:00
cgnannan
6ddddffc0f update SDK version of elevenlabs and corresponding code snippets. 2024-04-01 06:26:39 +00:00
unknown
dd95f849d4 Merge branch 'master' of https://github.com/whw23/chatgpt-on-wechat 2024-03-30 01:08:07 +08:00
unknown
22c7f8fe9e add dall-e-2 retry_count limit 2024-03-30 01:07:52 +08:00
Haowei
3d47be1f49 Merge branch 'zhayujie:master' into master 2024-03-30 00:54:38 +08:00
weishao zeng
5e399c46b1 feat: 通义千问使用新版的sdk实现
现在项目使用的通义千问是旧版本的百炼sdk,
这里增加一个新版本sdk(dashscope)的实现
2024-03-27 19:12:39 +08:00
weishao zeng
38e1db7a37 feat: 增加moonshot api集成
moonshot本来可直接使用openai sdk,
但是要求openai sdk必须在1.0以上,与本项目冲突,
故现使用http接口对接的方式集成
2024-03-27 15:02:51 +08:00
Clivia
8309f7cdbe feat ReplyType.IMAGE 回复图片为空的BUG 2024-03-27 14:49:54 +08:00
zhayujie
b8cc62ae95 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2024-03-27 10:35:42 +08:00
zhayujie
c0eb433fa2 fix: remove unused import 2024-03-27 10:35:12 +08:00
zhayujie
7f857d66f6 docs: update README.md 2024-03-26 20:12:25 +08:00
zhayujie
93b14d38f4 Merge pull request #1837 from dividduang/master
blackroom
2024-03-26 16:10:18 +08:00
zhayujie
21825faab0 docs: update README.md 2024-03-26 16:01:05 +08:00
zhayujie
1fafd39298 fix: gemini session bug 2024-03-26 00:06:50 +08:00
WILMAR\dengjingren
23b750fc4f blackroom 2024-03-25 21:56:26 +08:00
zhayujie
90581c840d Merge pull request #1760 from xiexin12138/feature-优化智谱-AI-的命令操作
add feature 优化智谱 AI 的命令操作,使其支持重置会话
2024-03-25 21:43:23 +08:00
zhayujie
cac7a6228a fix: claude api optimize 2024-03-25 21:41:40 +08:00
zhayujie
674fbc3f69 Merge pull request #1810 from FB208/master
增加了claude api的调用方法
2024-03-25 20:42:59 +08:00
zhayujie
9577bf1cc7 Merge pull request #1724 from stx116/patch-1
Update xunfei_spark_bot.py修改,修改讯飞大语言模型至3.5版本
2024-03-25 15:31:48 +08:00
zhayujie
654ebe93e7 Merge branch 'master' into patch-1 2024-03-25 15:31:38 +08:00
zhayujie
ecb1b3c491 Merge pull request #1763 from JobsLee0/master
升级讯飞接口版本及协议,避免11200错误码问题[Update xunfei_spark_bot.py]
2024-03-25 15:29:12 +08:00
zhayujie
c3d1711edc Merge branch 'master' into master 2024-03-25 15:28:41 +08:00
zhayujie
c12c7f10f0 Merge pull request #1826 from Meng-de-Cao/master
Update xunfei_spark_bot.py
2024-03-25 15:26:53 +08:00
zhayujie
f71820bf4e Merge pull request #1787 from uxfion/edge-tts
feat: edge-tts
2024-03-25 15:24:14 +08:00
Haowei
748c53c774 Merge branch 'zhayujie:master' into master 2024-03-23 21:13:36 +08:00
zhayujie
b290a71bfb Merge pull request #1686 from xiaodonghsu/new
百度语音转写支持8000采样率, pcm_s16le编码, 单通道语音的组合
2024-03-21 15:47:20 +08:00
Saboteur7
3204c51eca Merge pull request #1412 from Yanyutin753/patch-6
Update source.json
2024-03-21 15:39:42 +08:00
Saboteur7
2c4b8a44dc Merge pull request #1816 from xywhnh/master
修复gemini 插件的两个问题
2024-03-21 15:34:42 +08:00
卡Q因
943aa05eaa Update xunfei_spark_bot.py
默认使用讯飞3.5模型
2024-03-20 21:22:15 +08:00
Haowei
d0fd36e7e1 Merge branch 'zhayujie:master' into master 2024-03-20 15:31:31 +08:00
zhayujie
f45ff5fd0a Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2024-03-20 12:08:07 +08:00
zhayujie
c22c7102d5 fix: no need to send when message is empty 2024-03-20 12:07:05 +08:00
Saboteur7
11ecfd1b41 Merge pull request #1819 from 13476573407/master
由于使用#scanp和#reloadp扫描插件时,当更新已存在的插件以后并不会实现重载更新后的插件
2024-03-20 12:04:01 +08:00
Saboteur7
798e30e5ac Merge pull request #1821 from gufei/fix-bug
修复两处BUG
2024-03-20 11:50:40 +08:00
13476573407
15e0702329 解决使用scanp重载时会重新生成godcmd的实例,导致auth权限被清空 2024-03-20 10:52:34 +08:00
13476573407
a2bc22c37d 由于使用#scanp和#reloadp扫描插件时,当更新插件以后并不会实现重载新的插件
所以取消了已载入的插件判断重载除Godcmd以外的所有插件来实现不需要重启项目即可更新插件
2024-03-18 14:40:01 +08:00
rowan.wu
8093fcc64c 修复两处BUG
1、类型定义中使用了驼峰,但其他位置使用的大写
2、微信channel中,发送IMAGE,多余了seek方法
2024-03-16 12:34:40 +08:00
熊伟(10007228)
800419e7cc 修复如下问题:
1.调用gemini api出现异常时没有向下游返回错误信息,后续处理流程可能要根据错误信息做相应补偿机制
2.修复特殊场景中出现索引越界导导致应用退出
2024-03-14 13:44:14 +08:00
FB208
a241dc6785 Update README.md 2024-03-12 13:09:55 +08:00
FB208
805bea0d5f 增加了claude api的调用方法 2024-03-12 10:39:51 +08:00
unknown
9d394adf24 1.修复Azure Openai Dalle请求 2.增加Azure Openai Dalle3 请求参数 3.将用于回复文字和回复Dalle3的Azure Openai资源分离开 2024-03-12 08:32:24 +08:00
Saboteur7
2074f27aff Merge pull request #1806 from goldfishh/master
disable plugin(tool) log printing
2024-03-10 13:28:32 +08:00
goldfishh
283ad48b86 disable plugin(tool) log printing 2024-03-10 13:11:45 +08:00
zhayujie
07e10a7943 Update README.md 2024-03-08 00:19:59 +08:00
zhayujie
2812a5026c Update README.md 2024-03-05 20:56:37 +08:00
Lecter
3a20461abf add edge-tts 2024-03-04 00:14:19 +08:00
Zhuoheng Lee
64ae3d1e21 Update xunfei_spark_bot.py
讯飞接口升级到v3.5版本,同时升级到wss协议,避免请求时出现11200错误码的问题
2024-02-21 14:14:19 +08:00
xiexin12138
a25d7ea65b add feature 优化智谱 AI 的命令操作,使其支持重置会话 2024-02-20 16:40:00 +08:00
zhayujie
74ebbdd761 fix: client resource usage bug 2024-02-19 13:32:32 +08:00
MasterKeee
a0427b569e 新增公众号的回复视频类型 2024-02-19 00:45:53 +08:00
zhayujie
5346dfdd8b feat: code tidying up 2024-02-05 12:21:50 +08:00
zhayujie
3ee4147285 Merge pull request #1723 from zRzRzRzRzRzRzR/master
支持ZhipuAI GLM系列模型和画图代码
2024-02-05 12:15:51 +08:00
zhayujie
c41e486bfc Update config.py 2024-02-05 12:15:28 +08:00
zhayujie
eda3ba92fd Merge branch 'master' into master 2024-02-05 12:14:26 +08:00
zhayujie
40255290b0 Merge pull request #1716 from wayshall/zhipu
feat: 增加智谱chatglm4模型支持
2024-02-05 12:05:07 +08:00
zhayujie
af5bc73dc0 feat: optimize consumer thread pool 2024-02-05 12:01:41 +08:00
zR
0247cd4c45 改善模型选择 2024-02-02 11:08:06 +08:00
stx116
916762cc8c Update xunfei_spark_bot.py
更新讯飞大语言模型到3.5版本
2024-02-01 15:18:56 +08:00
zR
d6fdf8ca2a 支持ZhipuAI GLM系列模型和画图代码 2024-02-01 11:31:56 +08:00
zhayujie
95708489c9 fix: wxcomapp user name 2024-01-31 16:24:29 +08:00
weishao zeng
ced0fa4608 feat: 增加智谱chatglm4模型支持 2024-01-30 10:17:53 +08:00
zhayujie
7e0fbd600f feat: add media send limit and interval 2024-01-29 11:46:00 +08:00
zhayujie
f33e4e0323 fix: close tool debug level 2024-01-27 11:08:44 +08:00
zhayujie
d0fd78497d Merge pull request #1680 from V-know/patch-1
Doc: 优化【服务器部署】
2024-01-26 16:29:11 +08:00
zhayujie
8045019603 feat: add 4-turbo-preview model 2024-01-26 16:21:11 +08:00
zhayujie
7d92b9435e Merge pull request #1678 from goldfishh/master
tool 0.5.0
2024-01-26 11:17:15 +08:00
zhayujie
1e0822703a fix: image num 2024-01-25 18:00:02 +08:00
zhayujie
0403ff88ef feat: image num limit 2024-01-25 15:45:24 +08:00
zhayujie
78376d591b fix: image limit 2024-01-25 15:40:52 +08:00
zhayujie
8e23d0df20 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2024-01-25 15:39:41 +08:00
zhayujie
9e281d20ab fix: image num limit 2024-01-25 15:34:59 +08:00
zhayujie
644bd4a106 Merge pull request #1698 from 6vision/6vision-patch-1
Update wework_message.py
2024-01-23 20:09:16 +08:00
zhayujie
7729e66a96 docs: update README.md 2024-01-23 20:01:55 +08:00
zhayujie
d67d6b7948 feat: knowledge base send file 2024-01-22 18:03:04 +08:00
vision
4c4a46bfbe Update wework_message.py 2024-01-22 13:38:11 +08:00
zhayujie
4536f9c177 feat: client mng 2024-01-19 14:38:14 +08:00
FMStereo
977d3bc02e 百度语音转写支持8000采样率, pcm_s16le编码, 单通道语音的组合 2024-01-18 12:46:18 +08:00
zhayujie
eae95dfef5 fix: api base bug 2024-01-17 18:25:57 +08:00
Cancellara
b67d4460ca Doc: 优化【服务器部署】
不必单独创建nohup.out文件
nohup 命令执行时会自动创建
2024-01-17 01:13:39 +08:00
goldfishh
3dea8311b1 change chatgpt_tool_hub version to 0.5.0 2024-01-16 23:39:40 +08:00
zhayujie
11f6e98874 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2024-01-16 23:22:10 +08:00
zhayujie
2609e595f4 fix: client host 2024-01-16 22:38:33 +08:00
zhayujie
ac6e41abc8 Merge pull request #1644 from PoseidonLi0514/master
Image generation supports custom endpoint
2024-01-16 22:35:57 +08:00
zhayujie
9c17e16d0a fix: optimize code format 2024-01-16 19:17:32 +08:00
goldfishh
55e9064307 tool ver0.5
1. 新增工具pure模式,支持单个工具调用
2. 新增消息转发工具:email, sms, wechat, 可以根据规则向其他平台发送消息
3. 替换visual-dl(更名为visual)实现,目前识别图片链接效果较好。
4. 修复了0.4版本大部分工具返回结果不可靠问题
2024-01-16 01:13:40 +08:00
zhayujie
91cabd7d49 Merge pull request #1628 from huiwenTT/dingdinggpt
添加语音发送消息
2024-01-15 22:45:46 +08:00
zhayujie
7456950530 Merge pull request #1658 from I-E-E-E/patch-1
fixed a typo
2024-01-15 22:41:12 +08:00
zhayujie
8fcdda625d Merge pull request #1675 from zhayujie/feat-client
feat: channel client
2024-01-15 22:37:53 +08:00
zhayujie
40a10ee926 Merge branch 'master' into feat-client 2024-01-15 22:37:47 +08:00
zhayujie
c3f7e2645c feat: channel client 2024-01-15 22:35:30 +08:00
I-E-E-E
b264af1892 fixed a typo 2024-01-08 17:51:15 +08:00
Haikui Yang
43e93e8e22 Update open_ai_image.py 2024-01-01 22:43:03 +08:00
Haikui Yang
d6c4789688 Merge branch 'zhayujie:master' into master 2024-01-01 22:42:10 +08:00
惠文
cb31ee6f01 Merge branch 'dingdinggpt' of github.com:huiwenTT/chatgpt-on-wechat-1 into dingdinggpt 2023-12-26 15:56:35 +08:00
huiwen
f7b694ac56 添加语音发送消息和修复上下文的关联 2023-12-26 14:48:54 +08:00
zhayujie
eb809055d4 Merge pull request #1559 from huiwenTT/dingdinggpt
钉钉机器人
2023-12-25 18:15:33 +08:00
zhayujie
78d9be82b2 fix: add gemini dependency 2023-12-19 11:47:33 +08:00
Haikui Yang
76a95c0226 Update open_ai_image.py 2023-12-17 19:50:06 +08:00
huiwen
d3ab8fb04a Merge branch 'dingdinggpt' of 47.98.110.173:/opt/python_app/gpt into dingdinggpt 2023-12-17 09:52:24 +08:00
huiwen
f7a0b63a00 Merge branch 'zhayujie:master' into dingdinggpt 2023-12-17 09:27:30 +08:00
huiwen
a21dd97786 钉钉app_id,变更为_client_id,和逻辑优化 2023-12-17 09:23:15 +08:00
zhayujie
04943c0bfa Update README.md 2023-12-16 01:11:05 +08:00
zhayujie
203d4d8bfb Update README.md 2023-12-15 19:16:13 +08:00
zhayujie
c049a619dc chore: remove useless code 2023-12-15 16:49:23 +08:00
zhayujie
cc1b14b607 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2023-12-15 14:44:54 +08:00
zhayujie
e04a12a8f4 Merge branch 'hanfangyuan4396-master' 2023-12-15 14:40:34 +08:00
zhayujie
a2c82bc583 Merge branch 'master' of https://github.com/hanfangyuan4396/chatgpt-on-wechat into hanfangyuan4396-master 2023-12-15 14:40:15 +08:00
zhayujie
b4dc382f7c Merge pull request #1598 from zhayujie/feat-gemini
feat: support gemini model
2023-12-15 14:24:26 +08:00
zhayujie
eca1892e2a fix: gemini no content bug 2023-12-15 14:23:36 +08:00
zhayujie
23a237074e feat: support gemini model 2023-12-15 10:19:48 +08:00
zhayujie
219e9eca4f Merge pull request #1595 from 6vision/master
企微优化
2023-12-14 12:00:28 +08:00
6vision
413e09fb9e 1、企微个人号支持文件和链接消息
2、修复企微个人号群名获取bug
2023-12-14 00:50:34 +08:00
zhayujie
3514c37e4c fix: railway fork does not need action 2023-12-13 20:57:04 +08:00
zhayujie
95260e303c fix: process markdown url in knowledge base 2023-12-11 20:48:13 +08:00
hanfangyuan4396
0cef34bdfa Merge branch 'zhayujie:master' into master 2023-12-09 19:41:01 +08:00
Han Fangyuan
9838979bbd refactor: update class name of qwen bot 2023-12-09 19:40:07 +08:00
Han Fangyuan
c8910b8e14 fix: set correct top_p params of ali qwen model 2023-12-09 19:26:11 +08:00
Han Fangyuan
207fa1d019 feat: hot reload conf of ali qwen model 2023-12-09 18:40:17 +08:00
zhayujie
be0bb591e7 fix: do not draw when text_to_image is empty 2023-12-09 17:12:08 +08:00
Han Fangyuan
bfacdb9c3b feat: support character description of ali qwen model 2023-12-09 12:39:09 +08:00
zhayujie
ae4077ed6c fix: config adjust 2023-12-08 14:29:14 +08:00
zhayujie
6eb3c90e18 feat: qwen model modify 2023-12-08 14:12:21 +08:00
zhayujie
8c2a53a504 Merge pull request #1573 from chazzjimel/master
add ali voice output
2023-12-08 13:34:54 +08:00
zhayujie
74db1e0308 Merge pull request #1537 from hanfangyuan4396/master
支持阿里云百炼平台通义千问模型
2023-12-08 13:27:52 +08:00
zhayujie
b9dfdcef3d Merge pull request #1577 from xyshell/patch-1
Update chat_gpt_bot.py retry APIConnectionError
2023-12-08 13:26:59 +08:00
zhayujie
9d4afeac31 feat: speech support app_code bind 2023-12-07 22:44:43 +08:00
zhayujie
14ae2f169a fix: hello plugin trigger app bug 2023-12-07 19:41:50 +08:00
You Xie
55df19142f Update chat_gpt_bot.py retry APIConnectionError 2023-12-06 02:27:22 -06:00
zhayujie
40fd545b2c fix: exit group optimize 2023-12-06 10:51:47 +08:00
zhayujie
95fb07343e Merge pull request #1570 from erayyym/master
adding features: 退群提醒
2023-12-06 10:42:15 +08:00
erayyym
4d87906559 增加了配置项
本地跑没有问题,用户打开这个功能需要在config.json加入  "group_chat_exit_group": true,

(但是不确定写的对不对,刚开始学cs哈哈,之前没搞过这个)
2023-12-05 13:18:42 -05:00
跃迁
6b30dced43 Merge branch 'zhayujie:master' into master 2023-12-06 00:44:18 +08:00
chazzjimel
293a03b7c8 add ali voice output
增加阿里云语音输出接口
2023-12-06 00:43:19 +08:00
zhayujie
c010549f17 Merge pull request #1563 from malsony/master
Update xunfei_spark_bot.py
2023-12-06 00:40:23 +08:00
zhayujie
cc0be22026 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2023-12-06 00:31:59 +08:00
zhayujie
e5ba26febe fix: tts voice base url 2023-12-06 00:31:31 +08:00
erayyym
36f9680eec adding features: 退群提醒
后面还打算想办法加用户自己退出的提醒,目前版本是可以在群主(且群主/管理员自己是bot)踢人时候发出提醒
2023-12-05 03:58:42 -05:00
zhayujie
f4f5be5b08 Create LICENSE 2023-12-04 11:14:55 +08:00
chazzjimel
d89b056886 add ali voice output
增加阿里云语音输出支持。
2023-12-03 18:19:03 +08:00
malsony
65424c7db9 Update xunfei_spark_bot.py
update API URL for v3.0 version of Xunfei Spark.
2023-12-01 16:09:15 +08:00
huiwen
32a8a847fc 修复小bug 2023-11-30 12:09:03 +08:00
zhayujie
88fb3dbf60 fix: generate break by bug 2023-11-30 11:51:04 +08:00
惠文
f6bee3aa58 新增钉钉机器人(Stream模式) 2023-11-30 10:41:34 +08:00
zhayujie
5f19f37dcb feat: hello plugin support app code 2023-11-29 23:15:31 +08:00
zhayujie
dd36d8ce9e Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2023-11-29 17:41:44 +08:00
zhayujie
865e4b5349 feat: hello plugin support system prompt 2023-11-29 17:41:14 +08:00
hanfangyuan4396
e70564752b Merge branch 'master' into master 2023-11-29 10:16:50 +08:00
zhayujie
6e0d2f9437 fix: remove unuse log and add plugin config in docker config 2023-11-28 16:29:32 +08:00
zhayujie
291f936097 Update README.md 2023-11-27 20:24:42 +08:00
Han Fangyuan
c1022feab8 fix: add tongyi model to model list 2023-11-25 10:06:10 +08:00
Han Fangyuan
8d07ba6332 fix: add tongyi type when init bridge 2023-11-19 23:00:18 +08:00
Han Fangyuan
4ce37f84e4 feat: support Tongyi Qwen model of alibaba 2023-11-19 22:42:44 +08:00
Clivia
854d613a81 Update source.json 2023-09-09 12:25:40 +08:00
443 changed files with 55689 additions and 7455 deletions

13
.flake8
View File

@@ -1,13 +0,0 @@
[flake8]
max-line-length = 176
select = E303,W293,W291,W292,E305,E231,E302
exclude =
.tox,
__pycache__,
*.pyc,
.env
venv/*
.venv/*
reports/*
dist/*
lib/*

View File

@@ -79,8 +79,6 @@ body:
description: |
请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。
options:
- wx(个人微信, itchat)
- wxy(个人微信, wechaty)
- wechatmp(公众号, 订阅号)
- wechatmp_service(公众号, 服务号)
- terminal

View File

@@ -19,6 +19,7 @@ env:
jobs:
build-and-push-image:
if: github.repository == 'zhayujie/chatgpt-on-wechat'
runs-on: ubuntu-latest
permissions:
contents: read

View File

@@ -19,6 +19,7 @@ env:
jobs:
build-and-push-image:
if: github.repository == 'zhayujie/chatgpt-on-wechat'
runs-on: ubuntu-latest
permissions:
contents: read

22
.gitignore vendored
View File

@@ -3,17 +3,19 @@
.vscode
.venv
.vs
.wechaty/
__pycache__/
venv*
*.pyc
python
config.json
QR.png
nohup.out
tmp
plugins.json
itchat.pkl
*.log
logs/
workspace
config.yaml
user_datas.pkl
chatgpt_tool_hub/
plugins/**/
@@ -29,4 +31,18 @@ plugins/banwords/lib/__pycache__
!plugins/hello
!plugins/role
!plugins/keyword
!plugins/linkai
!plugins/linkai
!plugins/agent
!plugins/cow_cli
client_config.json
ref/
**/.dev.vars
.cursor/
local/
node_modules/
# cow cli
dist/
build/
*.egg-info/
.cow.pid

View File

@@ -1,30 +0,0 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: fix-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
- id: debug-statements
- id: pretty-format-json
types: [text]
files: \.json(.template)?$
args: [ --autofix , --no-ensure-ascii, --indent=2, --no-sort-keys]
- id: trailing-whitespace
exclude: '(\/|^)lib\/'
args: [ --markdown-linebreak-ext=md ]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
exclude: '(\/|^)lib\/'
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
exclude: '(\/|^)lib\/'
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
exclude: '(\/|^)lib\/'

908
README.md

File diff suppressed because it is too large Load Diff

3
agent/chat/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from agent.chat.service import ChatService
__all__ = ["ChatService"]

281
agent/chat/service.py Normal file
View File

@@ -0,0 +1,281 @@
"""
ChatService - Wraps the Agent stream execution to produce CHAT protocol chunks.
Translates agent events (message_update, message_end, tool_execution_end, etc.)
into the CHAT socket protocol format (content chunks with segment_id, tool_calls chunks).
"""
import time
from typing import Callable, Optional
from common.log import logger
class ChatService:
"""
High-level service that runs an Agent for a given query and streams
the results as CHAT protocol chunks via a callback.
Usage:
svc = ChatService(agent_bridge)
svc.run(query, session_id, send_chunk_fn)
"""
def __init__(self, agent_bridge):
"""
:param agent_bridge: AgentBridge instance (manages agent lifecycle)
"""
self.agent_bridge = agent_bridge
def run(self, query: str, session_id: str, send_chunk_fn: Callable[[dict], None],
channel_type: str = ""):
"""
Run the agent for *query* and stream results back via *send_chunk_fn*.
The method blocks until the agent finishes. After it returns the SDK
will automatically send the final (streaming=false) message.
:param query: user query text
:param session_id: session identifier for agent isolation
:param send_chunk_fn: callable(chunk_data: dict) to send a streaming chunk
:param channel_type: source channel (e.g. "web", "feishu") for persistence
"""
agent = self.agent_bridge.get_agent(session_id=session_id)
if agent is None:
raise RuntimeError("Failed to initialise agent for the session")
# Pass context metadata to model for downstream API requests
if hasattr(agent, 'model'):
agent.model.channel_type = channel_type or ""
agent.model.session_id = session_id or ""
# State shared between the event callback and this method
state = _StreamState()
def on_event(event: dict):
"""Translate agent events into CHAT protocol chunks."""
event_type = event.get("type")
data = event.get("data", {})
if event_type == "message_update":
# Incremental text delta
delta = data.get("delta", "")
if delta:
send_chunk_fn({
"chunk_type": "content",
"delta": delta,
"segment_id": state.segment_id,
})
elif event_type == "message_end":
# A content segment finished.
tool_calls = data.get("tool_calls", [])
if tool_calls:
# After tool_calls are executed the next content will be
# a new segment; collect tool results until turn_end.
state.pending_tool_results = []
elif event_type == "file_to_send":
url = data.get("url") or ""
if url:
fname = data.get("file_name") or "file"
ft = data.get("file_type") or "file"
if ft == "image":
link = f"![{fname}]({url})"
else:
link = f"[{fname}]({url})"
send_chunk_fn({
"chunk_type": "content",
"delta": "\n\n" + link + "\n\n",
"segment_id": state.segment_id,
})
# Remove url so the model won't repeat it in its reply
data.pop("url", None)
elif event_type == "tool_execution_start":
# Notify the client that a tool is about to run (with its input args)
tool_name = data.get("tool_name", "")
arguments = data.get("arguments", {})
# Cache arguments keyed by tool_call_id so tool_execution_end can include them
tool_call_id = data.get("tool_call_id", tool_name)
state.pending_tool_arguments[tool_call_id] = arguments
send_chunk_fn({
"chunk_type": "tool_start",
"tool": tool_name,
"arguments": arguments,
})
elif event_type == "tool_execution_end":
tool_name = data.get("tool_name", "")
tool_call_id = data.get("tool_call_id", tool_name)
# Retrieve cached arguments from the matching tool_execution_start event
arguments = state.pending_tool_arguments.pop(tool_call_id, data.get("arguments", {}))
result = data.get("result", "")
status = data.get("status", "unknown")
execution_time = data.get("execution_time", 0)
elapsed_str = f"{execution_time:.2f}s"
# Serialise result to string if needed
if not isinstance(result, str):
import json
try:
result = json.dumps(result, ensure_ascii=False)
except Exception:
result = str(result)
tool_info = {
"name": tool_name,
"arguments": arguments,
"result": result,
"status": status,
"elapsed": elapsed_str,
}
if state.pending_tool_results is not None:
state.pending_tool_results.append(tool_info)
elif event_type == "turn_end":
has_tool_calls = data.get("has_tool_calls", False)
if has_tool_calls and state.pending_tool_results:
# Flush collected tool results as a single tool_calls chunk
send_chunk_fn({
"chunk_type": "tool_calls",
"tool_calls": state.pending_tool_results,
})
state.pending_tool_results = None
# Next content belongs to a new segment
state.segment_id += 1
# Run the agent with our event callback ---------------------------
logger.info(f"[ChatService] Starting agent run: session={session_id}, query={query[:80]}")
from config import conf
max_context_turns = conf().get("agent_max_context_turns", 20)
# Get full system prompt with skills
full_system_prompt = agent.get_full_system_prompt()
# Create a copy of messages for this execution
with agent.messages_lock:
messages_copy = agent.messages.copy()
original_length = len(agent.messages)
from agent.protocol.agent_stream import AgentStreamExecutor
executor = AgentStreamExecutor(
agent=agent,
model=agent.model,
system_prompt=full_system_prompt,
tools=agent.tools,
max_turns=agent.max_steps,
on_event=on_event,
messages=messages_copy,
max_context_turns=max_context_turns,
)
try:
response = executor.run_stream(query)
except Exception:
# If executor cleared messages (context overflow), sync back
if len(executor.messages) == 0:
with agent.messages_lock:
agent.messages.clear()
logger.info("[ChatService] Cleared agent message history after executor recovery")
raise
# Sync executor messages back to agent (thread-safe).
# The executor may have trimmed context, making its list shorter than
# original_length. In that case we must replace entirely — just
# appending would leave stale pre-trim messages in agent.messages
# and cause the same trim to fire on every subsequent request.
with agent.messages_lock:
trimmed = len(executor.messages) < original_length
if trimmed:
# Context was trimmed: the executor appended the new user
# query *before* trimming, so the new messages (user +
# assistant + tools) sit at the tail of the trimmed list.
# We cannot simply slice at original_length (it exceeds the
# list length). Instead, count how many messages the
# executor added on top of the post-trim baseline.
#
# Timeline inside executor.run_stream:
# 1. messages had `original_length` items
# 2. append user query → original_length + 1
# 3. _trim_messages() → some smaller number (includes the
# user query because it belongs to the last turn)
# 4. LLM replies / tool calls appended
#
# The user query message is always the first message of the
# last turn (it cannot be trimmed away), so we locate it to
# find where "new" messages begin.
new_start = original_length # fallback
for idx in range(len(executor.messages) - 1, -1, -1):
msg = executor.messages[idx]
if msg.get("role") == "user":
content = msg.get("content", [])
is_user_query = False
if isinstance(content, list):
has_text = any(
isinstance(b, dict) and b.get("type") == "text"
for b in content
)
has_tool_result = any(
isinstance(b, dict) and b.get("type") == "tool_result"
for b in content
)
is_user_query = has_text and not has_tool_result
elif isinstance(content, str):
is_user_query = True
if is_user_query:
new_start = idx
break
new_messages = list(executor.messages[new_start:])
else:
new_messages = list(executor.messages[original_length:])
agent.messages = list(executor.messages)
# Persist new messages to SQLite so they survive restarts and
# can be queried via the HISTORY interface.
if new_messages:
self._persist_messages(session_id, list(new_messages), channel_type)
# Store executor reference for files_to_send access
agent.stream_executor = executor
# Execute post-process tools
agent._execute_post_process_tools()
logger.info(f"[ChatService] Agent run completed: session={session_id}")
@staticmethod
def _persist_messages(session_id: str, new_messages: list, channel_type: str = ""):
try:
from config import conf
if not conf().get("conversation_persistence", True):
return
except Exception:
pass
try:
from agent.memory import get_conversation_store
get_conversation_store().append_messages(
session_id, new_messages, channel_type=channel_type
)
except Exception as e:
logger.warning(
f"[ChatService] Failed to persist messages for session={session_id}: {e}"
)
class _StreamState:
"""Mutable state shared between the event callback and the run method."""
def __init__(self):
self.segment_id: int = 0
# None means we are not accumulating tool results right now.
# A list means we are in the middle of a tool-execution phase.
self.pending_tool_results: Optional[list] = None
# Maps tool_call_id -> arguments captured from tool_execution_start,
# so that tool_execution_end can attach the correct input args.
self.pending_tool_arguments: dict = {}

23
agent/memory/__init__.py Normal file
View File

@@ -0,0 +1,23 @@
"""
Memory module for AgentMesh
Provides both long-term memory (vector/keyword search) and short-term
conversation history persistence (SQLite).
"""
from agent.memory.manager import MemoryManager
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
from agent.memory.embedding import create_embedding_provider
from agent.memory.conversation_store import ConversationStore, get_conversation_store
from agent.memory.summarizer import ensure_daily_memory_file
__all__ = [
'MemoryManager',
'MemoryConfig',
'get_default_memory_config',
'set_global_memory_config',
'create_embedding_provider',
'ConversationStore',
'get_conversation_store',
'ensure_daily_memory_file',
]

140
agent/memory/chunker.py Normal file
View File

@@ -0,0 +1,140 @@
"""
Text chunking utilities for memory
Splits text into chunks with token limits and overlap
"""
from __future__ import annotations
from typing import List, Tuple
from dataclasses import dataclass
@dataclass
class TextChunk:
"""Represents a text chunk with line numbers"""
text: str
start_line: int
end_line: int
class TextChunker:
"""Chunks text by line count with token estimation"""
def __init__(self, max_tokens: int = 500, overlap_tokens: int = 50):
"""
Initialize chunker
Args:
max_tokens: Maximum tokens per chunk
overlap_tokens: Overlap tokens between chunks
"""
self.max_tokens = max_tokens
self.overlap_tokens = overlap_tokens
# Rough estimation: ~4 chars per token for English/Chinese mixed
self.chars_per_token = 4
def chunk_text(self, text: str) -> List[TextChunk]:
"""
Chunk text into overlapping segments
Args:
text: Input text to chunk
Returns:
List of TextChunk objects
"""
if not text.strip():
return []
lines = text.split('\n')
chunks = []
max_chars = self.max_tokens * self.chars_per_token
overlap_chars = self.overlap_tokens * self.chars_per_token
current_chunk = []
current_chars = 0
start_line = 1
for i, line in enumerate(lines, start=1):
line_chars = len(line)
# If single line exceeds max, split it
if line_chars > max_chars:
# Save current chunk if exists
if current_chunk:
chunks.append(TextChunk(
text='\n'.join(current_chunk),
start_line=start_line,
end_line=i - 1
))
current_chunk = []
current_chars = 0
# Split long line into multiple chunks
for sub_chunk in self._split_long_line(line, max_chars):
chunks.append(TextChunk(
text=sub_chunk,
start_line=i,
end_line=i
))
start_line = i + 1
continue
# Check if adding this line would exceed limit
if current_chars + line_chars > max_chars and current_chunk:
# Save current chunk
chunks.append(TextChunk(
text='\n'.join(current_chunk),
start_line=start_line,
end_line=i - 1
))
# Start new chunk with overlap
overlap_lines = self._get_overlap_lines(current_chunk, overlap_chars)
current_chunk = overlap_lines + [line]
current_chars = sum(len(l) for l in current_chunk)
start_line = i - len(overlap_lines)
else:
# Add line to current chunk
current_chunk.append(line)
current_chars += line_chars
# Save last chunk
if current_chunk:
chunks.append(TextChunk(
text='\n'.join(current_chunk),
start_line=start_line,
end_line=len(lines)
))
return chunks
def _split_long_line(self, line: str, max_chars: int) -> List[str]:
"""Split a single long line into multiple chunks"""
chunks = []
for i in range(0, len(line), max_chars):
chunks.append(line[i:i + max_chars])
return chunks
def _get_overlap_lines(self, lines: List[str], target_chars: int) -> List[str]:
"""Get last few lines that fit within target_chars for overlap"""
overlap = []
chars = 0
for line in reversed(lines):
line_chars = len(line)
if chars + line_chars > target_chars:
break
overlap.insert(0, line)
chars += line_chars
return overlap
def chunk_markdown(self, text: str) -> List[TextChunk]:
"""
Chunk markdown text while respecting structure
(For future enhancement: respect markdown sections)
"""
return self.chunk_text(text)

122
agent/memory/config.py Normal file
View File

@@ -0,0 +1,122 @@
"""
Memory configuration module
Provides global memory configuration with simplified workspace structure
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import Optional, List
from pathlib import Path
def _default_workspace():
"""Get default workspace path with proper Windows support"""
from common.utils import expand_path
return expand_path("~/cow")
@dataclass
class MemoryConfig:
"""Configuration for memory storage and search"""
# Storage paths (default: ~/cow)
workspace_root: str = field(default_factory=_default_workspace)
# Embedding config
embedding_provider: str = "openai" # "openai" | "local"
embedding_model: str = "text-embedding-3-small"
embedding_dim: int = 1536
# Chunking config
chunk_max_tokens: int = 500
chunk_overlap_tokens: int = 50
# Search config
max_results: int = 10
min_score: float = 0.1
# Hybrid search weights
vector_weight: float = 0.7
keyword_weight: float = 0.3
# Memory sources
sources: List[str] = field(default_factory=lambda: ["memory", "session"])
# Sync config
enable_auto_sync: bool = True
sync_on_search: bool = True
def get_workspace(self) -> Path:
"""Get workspace root directory"""
return Path(self.workspace_root)
def get_memory_dir(self) -> Path:
"""Get memory files directory"""
return self.get_workspace() / "memory"
def get_db_path(self) -> Path:
"""Get SQLite database path for long-term memory index"""
index_dir = self.get_memory_dir() / "long-term"
index_dir.mkdir(parents=True, exist_ok=True)
return index_dir / "index.db"
def get_skills_dir(self) -> Path:
"""Get skills directory"""
return self.get_workspace() / "skills"
def get_agent_workspace(self, agent_name: Optional[str] = None) -> Path:
"""
Get workspace directory for an agent
Args:
agent_name: Optional agent name (not used in current implementation)
Returns:
Path to workspace directory
"""
workspace = self.get_workspace()
# Ensure workspace directory exists
workspace.mkdir(parents=True, exist_ok=True)
return workspace
# Global memory configuration
_global_memory_config: Optional[MemoryConfig] = None
def get_default_memory_config() -> MemoryConfig:
"""
Get the global memory configuration.
If not set, returns a default configuration.
Returns:
MemoryConfig instance
"""
global _global_memory_config
if _global_memory_config is None:
_global_memory_config = MemoryConfig()
return _global_memory_config
def set_global_memory_config(config: MemoryConfig):
"""
Set the global memory configuration.
This should be called before creating any MemoryManager instances.
Args:
config: MemoryConfig instance to use globally
Example:
>>> from agent.memory import MemoryConfig, set_global_memory_config
>>> config = MemoryConfig(
... workspace_root="~/my_agents",
... embedding_provider="openai",
... vector_weight=0.8
... )
>>> set_global_memory_config(config)
"""
global _global_memory_config
_global_memory_config = config

View File

@@ -0,0 +1,618 @@
"""
Conversation history persistence using SQLite.
Design:
- sessions table: per-session metadata (channel_type, last_active, msg_count)
- messages table: individual messages stored as JSON, append-only
- Pruning: age-based only (sessions not updated within N days are deleted)
- Thread-safe via a single in-process lock
Storage path: ~/cow/sessions/conversations.db
"""
from __future__ import annotations
import json
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from common.log import logger
# ---------------------------------------------------------------------------
# Schema
# ---------------------------------------------------------------------------
_DDL = """
CREATE TABLE IF NOT EXISTS sessions (
session_id TEXT PRIMARY KEY,
channel_type TEXT NOT NULL DEFAULT '',
created_at INTEGER NOT NULL,
last_active INTEGER NOT NULL,
msg_count INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL,
seq INTEGER NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
created_at INTEGER NOT NULL,
UNIQUE (session_id, seq)
);
CREATE INDEX IF NOT EXISTS idx_messages_session
ON messages (session_id, seq);
CREATE INDEX IF NOT EXISTS idx_sessions_last_active
ON sessions (last_active);
"""
# Migration: add channel_type column to existing databases that predate it.
_MIGRATION_ADD_CHANNEL_TYPE = """
ALTER TABLE sessions ADD COLUMN channel_type TEXT NOT NULL DEFAULT '';
"""
DEFAULT_MAX_AGE_DAYS: int = 30
def _is_visible_user_message(content: Any) -> bool:
"""
Return True when a user-role message represents actual user input
(not an internal tool_result injected by the agent loop).
"""
if isinstance(content, str):
return bool(content.strip())
if isinstance(content, list):
return any(
isinstance(b, dict) and b.get("type") == "text"
for b in content
)
return False
def _extract_display_text(content: Any) -> str:
"""
Extract the human-readable text portion from a message content value.
Returns an empty string for tool_use / tool_result blocks.
"""
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
parts = [
b.get("text", "")
for b in content
if isinstance(b, dict) and b.get("type") == "text"
]
return "\n".join(p for p in parts if p).strip()
return ""
def _extract_tool_calls(content: Any) -> List[Dict[str, Any]]:
"""
Extract tool_use blocks from an assistant message content.
Returns a list of {name, arguments} dicts (result filled in later).
"""
if not isinstance(content, list):
return []
return [
{"id": b.get("id", ""), "name": b.get("name", ""), "arguments": b.get("input", {})}
for b in content
if isinstance(b, dict) and b.get("type") == "tool_use"
]
def _extract_tool_results(content: Any) -> Dict[str, str]:
"""
Extract tool_result blocks from a user message, keyed by tool_use_id.
"""
if not isinstance(content, list):
return {}
results = {}
for b in content:
if not isinstance(b, dict) or b.get("type") != "tool_result":
continue
tool_id = b.get("tool_use_id", "")
result_content = b.get("content", "")
if isinstance(result_content, list):
result_content = "\n".join(
rb.get("text", "") for rb in result_content
if isinstance(rb, dict) and rb.get("type") == "text"
)
results[tool_id] = str(result_content)
return results
def _group_into_display_turns(
rows: List[tuple],
) -> List[Dict[str, Any]]:
"""
Convert raw (role, content_json, created_at) DB rows into display turns.
One display turn = one visible user message + one merged assistant reply.
All intermediate assistant messages (those carrying tool_use) and the final
assistant text reply produced for the same user query are collapsed into a
single assistant turn, exactly matching the live SSE rendering where tools
and the final answer appear inside the same bubble.
Grouping rules:
- A visible user message starts a new group.
- tool_result user messages are internal; their content is attached to the
matching tool_use entry via tool_use_id and they never become own turns.
- All assistant messages within a group are merged:
* tool_use blocks → tool_calls list (result filled from tool_results)
* text blocks → last non-empty text becomes the display content
"""
# ------------------------------------------------------------------ #
# Pass 1: split rows into groups, each starting with a visible user msg
# ------------------------------------------------------------------ #
# group = (user_row | None, [subsequent_rows])
# user_row: (content, created_at)
groups: List[tuple] = []
cur_user: Optional[tuple] = None
cur_rest: List[tuple] = []
started = False
for role, raw_content, created_at in rows:
try:
content = json.loads(raw_content)
except Exception:
content = raw_content
if role == "user" and _is_visible_user_message(content):
if started:
groups.append((cur_user, cur_rest))
cur_user = (content, created_at)
cur_rest = []
started = True
else:
cur_rest.append((role, content, created_at))
if started:
groups.append((cur_user, cur_rest))
# ------------------------------------------------------------------ #
# Pass 2: build display turns from each group
# ------------------------------------------------------------------ #
turns: List[Dict[str, Any]] = []
for user_row, rest in groups:
# User turn
if user_row:
content, created_at = user_row
text = _extract_display_text(content)
if text:
turns.append({"role": "user", "content": text, "created_at": created_at})
# Collect all tool_calls and tool_results from the rest of the group
all_tool_calls: List[Dict[str, Any]] = []
tool_results: Dict[str, str] = {}
final_text = ""
final_ts: Optional[int] = None
for role, content, created_at in rest:
if role == "user":
tool_results.update(_extract_tool_results(content))
elif role == "assistant":
tcs = _extract_tool_calls(content)
all_tool_calls.extend(tcs)
t = _extract_display_text(content)
if t:
final_text = t
final_ts = created_at
# Attach tool results to their matching tool_call entries
for tc in all_tool_calls:
tc["result"] = tool_results.get(tc.get("id", ""), "")
if final_text or all_tool_calls:
turns.append({
"role": "assistant",
"content": final_text,
"tool_calls": all_tool_calls,
"created_at": final_ts or (user_row[1] if user_row else 0),
})
return turns
class ConversationStore:
"""
SQLite-backed store for per-session conversation history.
Usage:
store = ConversationStore(db_path)
store.append_messages("user_123", new_messages, channel_type="feishu")
msgs = store.load_messages("user_123", max_turns=30)
"""
def __init__(self, db_path: Path):
self._db_path = db_path
self._lock = threading.Lock()
self._init_db()
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def load_messages(
self,
session_id: str,
max_turns: int = 30,
) -> List[Dict[str, Any]]:
"""
Load the most recent messages for a session, for injection into the LLM.
ALL message types (user text, assistant tool_use, tool_result) are returned
in their original JSON form so the LLM can reconstruct the full context.
max_turns is a *visible-turn* count: we count only user messages whose
content is actual user text (not tool_result blocks). This prevents
tool-heavy sessions from exhausting the turn budget prematurely.
Args:
session_id: Unique session identifier.
max_turns: Maximum number of visible user-assistant turns to keep.
Returns:
Chronologically ordered list of message dicts (role, content).
"""
with self._lock:
conn = self._connect()
try:
rows = conn.execute(
"""
SELECT seq, role, content
FROM messages
WHERE session_id = ?
ORDER BY seq DESC
""",
(session_id,),
).fetchall()
finally:
conn.close()
if not rows:
return []
# Walk newest-to-oldest counting *visible* user turns (actual user text,
# not tool_result injections). Record the seq of every visible user
# message so we can find a clean cut point later.
visible_turn_seqs: List[int] = [] # newest first
for seq, role, raw_content in rows:
if role != "user":
continue
try:
content = json.loads(raw_content)
except Exception:
content = raw_content
if _is_visible_user_message(content):
visible_turn_seqs.append(seq)
# Determine the seq of the oldest visible user message we want to keep.
# If the total turns fit within max_turns, keep everything.
if len(visible_turn_seqs) <= max_turns:
cutoff_seq = None # keep all
else:
# The Nth visible user message (0-indexed) is the oldest we keep.
cutoff_seq = visible_turn_seqs[max_turns - 1]
# Build result in chronological order, starting from cutoff.
# IMPORTANT: we start exactly at cutoff_seq (the visible user message),
# never mid-group, so tool_use / tool_result pairs are always complete.
result = []
for seq, role, raw_content in reversed(rows):
if cutoff_seq is not None and seq < cutoff_seq:
continue
try:
content = json.loads(raw_content)
except Exception:
content = raw_content
result.append({"role": role, "content": content})
return result
def append_messages(
self,
session_id: str,
messages: List[Dict[str, Any]],
channel_type: str = "",
) -> None:
"""
Append new messages to a session's history.
Seq numbers continue from the session's current maximum, so
concurrent callers on distinct sessions never collide.
Args:
session_id: Unique session identifier.
messages: List of message dicts to append.
channel_type: Source channel (e.g. "feishu", "web", "wechat").
Only written on session creation; ignored on update.
"""
if not messages:
return
now = int(time.time())
with self._lock:
conn = self._connect()
try:
with conn:
# INSERT OR IGNORE creates the row on first visit;
# the UPDATE always refreshes last_active.
# Avoids ON CONFLICT...DO UPDATE (requires SQLite >= 3.24).
conn.execute(
"""
INSERT OR IGNORE INTO sessions
(session_id, channel_type, created_at, last_active, msg_count)
VALUES (?, ?, ?, ?, 0)
""",
(session_id, channel_type, now, now),
)
conn.execute(
"UPDATE sessions SET last_active = ? WHERE session_id = ?",
(now, session_id),
)
# Determine starting seq for the new batch.
row = conn.execute(
"SELECT COALESCE(MAX(seq), -1) FROM messages WHERE session_id = ?",
(session_id,),
).fetchone()
next_seq = row[0] + 1
for msg in messages:
role = msg.get("role", "")
content = json.dumps(
msg.get("content", ""), ensure_ascii=False
)
conn.execute(
"""
INSERT OR IGNORE INTO messages
(session_id, seq, role, content, created_at)
VALUES (?, ?, ?, ?, ?)
""",
(session_id, next_seq, role, content, now),
)
next_seq += 1
conn.execute(
"""
UPDATE sessions
SET msg_count = (
SELECT COUNT(*) FROM messages WHERE session_id = ?
)
WHERE session_id = ?
""",
(session_id, session_id),
)
finally:
conn.close()
def clear_session(self, session_id: str) -> None:
"""Delete all messages and the session record for a given session_id."""
with self._lock:
conn = self._connect()
try:
with conn:
conn.execute(
"DELETE FROM messages WHERE session_id = ?", (session_id,)
)
conn.execute(
"DELETE FROM sessions WHERE session_id = ?", (session_id,)
)
finally:
conn.close()
def cleanup_old_sessions(self, max_age_days: Optional[int] = None) -> int:
"""
Delete sessions that have not been active within max_age_days.
Args:
max_age_days: Override the default retention period.
Returns:
Number of sessions deleted.
"""
try:
from config import conf
max_age = max_age_days or conf().get(
"conversation_max_age_days", DEFAULT_MAX_AGE_DAYS
)
except Exception:
max_age = max_age_days or DEFAULT_MAX_AGE_DAYS
cutoff = int(time.time()) - max_age * 86400
deleted = 0
with self._lock:
conn = self._connect()
try:
with conn:
stale = conn.execute(
"SELECT session_id FROM sessions WHERE last_active < ?",
(cutoff,),
).fetchall()
for (sid,) in stale:
conn.execute(
"DELETE FROM messages WHERE session_id = ?", (sid,)
)
conn.execute(
"DELETE FROM sessions WHERE session_id = ?", (sid,)
)
deleted += 1
finally:
conn.close()
if deleted:
logger.info(f"[ConversationStore] Pruned {deleted} expired sessions")
return deleted
def load_history_page(
self,
session_id: str,
page: int = 1,
page_size: int = 20,
) -> Dict[str, Any]:
"""
Load a page of conversation history for UI display, grouped into turns.
Each "turn" maps to one of:
- A user message (role="user", content=str)
- An assistant message (role="assistant", content=str,
tool_calls=[{name, arguments, result}] when tools were used)
Internal tool_result user messages are merged into the preceding
assistant entry's tool_calls list and never appear as standalone items.
Pages are numbered from 1 (most recent). Messages within a page are
returned in chronological order.
Returns:
{
"messages": [
{
"role": "user" | "assistant",
"content": str,
"tool_calls": [...], # assistant only, may be []
"created_at": int,
},
...
],
"total": <visible turn count>,
"page": <current page>,
"page_size": <page_size>,
"has_more": bool,
}
"""
page = max(1, page)
with self._lock:
conn = self._connect()
try:
rows = conn.execute(
"""
SELECT role, content, created_at
FROM messages
WHERE session_id = ?
ORDER BY seq ASC
""",
(session_id,),
).fetchall()
finally:
conn.close()
visible = _group_into_display_turns(rows)
total = len(visible)
offset = (page - 1) * page_size
page_items = list(reversed(visible))[offset: offset + page_size]
page_items = list(reversed(page_items))
return {
"messages": page_items,
"total": total,
"page": page,
"page_size": page_size,
"has_more": offset + page_size < total,
}
def get_stats(self) -> Dict[str, Any]:
"""Return basic stats keyed by channel_type, for monitoring."""
with self._lock:
conn = self._connect()
try:
total_sessions = conn.execute(
"SELECT COUNT(*) FROM sessions"
).fetchone()[0]
total_messages = conn.execute(
"SELECT COUNT(*) FROM messages"
).fetchone()[0]
by_channel = conn.execute(
"""
SELECT channel_type, COUNT(*) as cnt
FROM sessions
GROUP BY channel_type
ORDER BY cnt DESC
"""
).fetchall()
return {
"total_sessions": total_sessions,
"total_messages": total_messages,
"by_channel": {row[0] or "unknown": row[1] for row in by_channel},
}
finally:
conn.close()
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _init_db(self) -> None:
self._db_path.parent.mkdir(parents=True, exist_ok=True)
conn = self._connect()
try:
conn.executescript(_DDL)
conn.commit()
self._migrate(conn)
finally:
conn.close()
def _migrate(self, conn: sqlite3.Connection) -> None:
"""Apply incremental schema migrations on existing databases."""
cols = {
row[1]
for row in conn.execute("PRAGMA table_info(sessions)").fetchall()
}
if "channel_type" not in cols:
try:
conn.execute(_MIGRATION_ADD_CHANNEL_TYPE)
conn.commit()
logger.info("[ConversationStore] Migrated: added channel_type column")
except Exception as e:
logger.warning(f"[ConversationStore] Migration failed: {e}")
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(str(self._db_path), timeout=10)
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
return conn
# ---------------------------------------------------------------------------
# Singleton
# ---------------------------------------------------------------------------
_store_instance: Optional[ConversationStore] = None
_store_lock = threading.Lock()
def get_conversation_store() -> ConversationStore:
"""
Return the process-wide ConversationStore singleton.
Reuses the long-term memory database so the project stays with a single
SQLite file: ~/cow/memory/long-term/index.db
The conversation tables (sessions / messages) are separate from the
memory tables (memory_chunks / file_metadata) — no conflicts.
"""
global _store_instance
if _store_instance is not None:
return _store_instance
with _store_lock:
if _store_instance is not None:
return _store_instance
try:
from agent.memory.config import get_default_memory_config
db_path = get_default_memory_config().get_db_path()
except Exception:
from common.utils import expand_path
db_path = Path(expand_path("~/cow")) / "memory" / "long-term" / "index.db"
_store_instance = ConversationStore(db_path)
logger.debug(f"[ConversationStore] Using shared DB at: {db_path}")
return _store_instance

167
agent/memory/embedding.py Normal file
View File

@@ -0,0 +1,167 @@
"""
Embedding providers for memory
Supports OpenAI and local embedding models
"""
import hashlib
from abc import ABC, abstractmethod
from typing import List, Optional
class EmbeddingProvider(ABC):
"""Base class for embedding providers"""
@abstractmethod
def embed(self, text: str) -> List[float]:
"""Generate embedding for text"""
pass
@abstractmethod
def embed_batch(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for multiple texts"""
pass
@property
@abstractmethod
def dimensions(self) -> int:
"""Get embedding dimensions"""
pass
class OpenAIEmbeddingProvider(EmbeddingProvider):
"""OpenAI embedding provider using REST API"""
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None,
api_base: Optional[str] = None, extra_headers: Optional[dict] = None):
"""
Initialize OpenAI embedding provider
Args:
model: Model name (text-embedding-3-small or text-embedding-3-large)
api_key: OpenAI API key
api_base: Optional API base URL
extra_headers: Optional extra headers to include in API requests
"""
self.model = model
self.api_key = api_key
self.api_base = api_base or "https://api.openai.com/v1"
self.extra_headers = extra_headers or {}
# Validate API key
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
raise ValueError("OpenAI API key is not configured. Please set 'open_ai_api_key' in config.json")
# Set dimensions based on model
self._dimensions = 1536 if "small" in model else 3072
def _call_api(self, input_data):
"""Call OpenAI embedding API using requests"""
import requests
url = f"{self.api_base}/embeddings"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
**self.extra_headers,
}
data = {
"input": input_data,
"model": self.model
}
try:
response = requests.post(url, headers=headers, json=data, timeout=5)
response.raise_for_status()
return response.json()
except requests.exceptions.ConnectionError as e:
raise ConnectionError(f"Failed to connect to OpenAI API at {url}. Please check your network connection and api_base configuration. Error: {str(e)}")
except requests.exceptions.Timeout as e:
raise TimeoutError(f"OpenAI API request timed out after 10s. Please check your network connection. Error: {str(e)}")
except requests.exceptions.HTTPError as e:
if e.response.status_code == 401:
raise ValueError(f"Invalid OpenAI API key. Please check your 'open_ai_api_key' in config.json")
elif e.response.status_code == 429:
raise ValueError(f"OpenAI API rate limit exceeded. Please try again later.")
else:
raise ValueError(f"OpenAI API request failed: {e.response.status_code} - {e.response.text}")
def embed(self, text: str) -> List[float]:
"""Generate embedding for text"""
result = self._call_api(text)
return result["data"][0]["embedding"]
def embed_batch(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for multiple texts"""
if not texts:
return []
result = self._call_api(texts)
return [item["embedding"] for item in result["data"]]
@property
def dimensions(self) -> int:
return self._dimensions
# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search
class EmbeddingCache:
"""Cache for embeddings to avoid recomputation"""
def __init__(self):
self.cache = {}
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
"""Get cached embedding"""
key = self._compute_key(text, provider, model)
return self.cache.get(key)
def put(self, text: str, provider: str, model: str, embedding: List[float]):
"""Cache embedding"""
key = self._compute_key(text, provider, model)
self.cache[key] = embedding
@staticmethod
def _compute_key(text: str, provider: str, model: str) -> str:
"""Compute cache key"""
content = f"{provider}:{model}:{text}"
return hashlib.md5(content.encode('utf-8')).hexdigest()
def clear(self):
"""Clear cache"""
self.cache.clear()
def create_embedding_provider(
provider: str = "openai",
model: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
extra_headers: Optional[dict] = None
) -> EmbeddingProvider:
"""
Factory function to create embedding provider
Supports "openai" and "linkai" providers (both use OpenAI-compatible REST API).
If initialization fails, caller should fall back to keyword-only search.
Args:
provider: Provider name ("openai" or "linkai")
model: Model name (default: text-embedding-3-small)
api_key: API key (required)
api_base: API base URL
extra_headers: Optional extra headers to include in API requests
Returns:
EmbeddingProvider instance
Raises:
ValueError: If provider is unsupported or api_key is missing
"""
if provider not in ("openai", "linkai"):
raise ValueError(f"Unsupported embedding provider: {provider}. Use 'openai' or 'linkai'.")
model = model or "text-embedding-3-small"
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base, extra_headers=extra_headers)

527
agent/memory/manager.py Normal file
View File

@@ -0,0 +1,527 @@
"""
Memory manager for AgentMesh
Provides high-level interface for memory operations
"""
import os
from typing import List, Optional, Dict, Any
from pathlib import Path
import hashlib
from datetime import datetime, timedelta
from agent.memory.config import MemoryConfig, get_default_memory_config
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
from agent.memory.chunker import TextChunker
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
class MemoryManager:
"""
Memory manager with hybrid search capabilities
Provides long-term memory for agents with vector and keyword search
"""
def __init__(
self,
config: Optional[MemoryConfig] = None,
embedding_provider: Optional[EmbeddingProvider] = None,
llm_model: Optional[Any] = None
):
"""
Initialize memory manager
Args:
config: Memory configuration (uses global config if not provided)
embedding_provider: Custom embedding provider (optional)
llm_model: LLM model for summarization (optional)
"""
self.config = config or get_default_memory_config()
# Initialize storage
db_path = self.config.get_db_path()
self.storage = MemoryStorage(db_path)
# Initialize chunker
self.chunker = TextChunker(
max_tokens=self.config.chunk_max_tokens,
overlap_tokens=self.config.chunk_overlap_tokens
)
# Initialize embedding provider (optional, prefer OpenAI, fallback to LinkAI)
self.embedding_provider = None
if embedding_provider:
self.embedding_provider = embedding_provider
else:
# Try OpenAI first
try:
api_key = os.environ.get('OPENAI_API_KEY')
api_base = os.environ.get('OPENAI_API_BASE')
if api_key:
self.embedding_provider = create_embedding_provider(
provider="openai",
model=self.config.embedding_model,
api_key=api_key,
api_base=api_base
)
except Exception as e:
from common.log import logger
logger.warning(f"[MemoryManager] OpenAI embedding failed: {e}")
# Fallback to LinkAI
if self.embedding_provider is None:
try:
linkai_key = os.environ.get('LINKAI_API_KEY')
linkai_base = os.environ.get('LINKAI_API_BASE', 'https://api.link-ai.tech')
if linkai_key:
from common.utils import get_cloud_headers
cloud_headers = get_cloud_headers(linkai_key)
cloud_headers.pop("Authorization", None)
self.embedding_provider = create_embedding_provider(
provider="linkai",
model=self.config.embedding_model,
api_key=linkai_key,
api_base=f"{linkai_base}/v1",
extra_headers=cloud_headers,
)
except Exception as e:
from common.log import logger
logger.warning(f"[MemoryManager] LinkAI embedding failed: {e}")
if self.embedding_provider is None:
from common.log import logger
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
# Initialize memory flush manager
workspace_dir = self.config.get_workspace()
self.flush_manager = MemoryFlushManager(
workspace_dir=workspace_dir,
llm_model=llm_model
)
# Ensure workspace directories exist
self._init_workspace()
self._dirty = False
def _init_workspace(self):
"""Initialize workspace directories"""
memory_dir = self.config.get_memory_dir()
memory_dir.mkdir(parents=True, exist_ok=True)
# Create default memory files
workspace_dir = self.config.get_workspace()
create_memory_files_if_needed(workspace_dir)
async def search(
self,
query: str,
user_id: Optional[str] = None,
max_results: Optional[int] = None,
min_score: Optional[float] = None,
include_shared: bool = True
) -> List[SearchResult]:
"""
Search memory with hybrid search (vector + keyword)
Args:
query: Search query
user_id: User ID for scoped search
max_results: Maximum results to return
min_score: Minimum score threshold
include_shared: Include shared memories
Returns:
List of search results sorted by relevance
"""
max_results = max_results or self.config.max_results
min_score = min_score or self.config.min_score
# Determine scopes
scopes = []
if include_shared:
scopes.append("shared")
if user_id:
scopes.append("user")
if not scopes:
return []
# Sync if needed
if self.config.sync_on_search and self._dirty:
await self.sync()
# Perform vector search (if embedding provider available)
vector_results = []
if self.embedding_provider:
try:
from common.log import logger
query_embedding = self.embedding_provider.embed(query)
vector_results = self.storage.search_vector(
query_embedding=query_embedding,
user_id=user_id,
scopes=scopes,
limit=max_results * 2 # Get more candidates for merging
)
logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
except Exception as e:
from common.log import logger
logger.warning(f"[MemoryManager] Vector search failed: {e}")
# Perform keyword search
keyword_results = self.storage.search_keyword(
query=query,
user_id=user_id,
scopes=scopes,
limit=max_results * 2
)
from common.log import logger
logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
# Merge results
merged = self._merge_results(
vector_results,
keyword_results,
self.config.vector_weight,
self.config.keyword_weight
)
# Filter by min score and limit
filtered = [r for r in merged if r.score >= min_score]
return filtered[:max_results]
async def add_memory(
self,
content: str,
user_id: Optional[str] = None,
scope: str = "shared",
source: str = "memory",
path: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
):
"""
Add new memory content
Args:
content: Memory content
user_id: User ID for user-scoped memory
scope: Memory scope ("shared", "user", "session")
source: Memory source ("memory" or "session")
path: File path (auto-generated if not provided)
metadata: Additional metadata
"""
if not content.strip():
return
# Generate path if not provided
if not path:
content_hash = hashlib.md5(content.encode('utf-8')).hexdigest()[:8]
if user_id and scope == "user":
path = f"memory/users/{user_id}/memory_{content_hash}.md"
else:
path = f"memory/shared/memory_{content_hash}.md"
# Chunk content
chunks = self.chunker.chunk_text(content)
# Generate embeddings (if provider available)
texts = [chunk.text for chunk in chunks]
if self.embedding_provider:
embeddings = self.embedding_provider.embed_batch(texts)
else:
# No embeddings, just use None
embeddings = [None] * len(texts)
# Create memory chunks
memory_chunks = []
for chunk, embedding in zip(chunks, embeddings):
chunk_id = self._generate_chunk_id(path, chunk.start_line, chunk.end_line)
chunk_hash = MemoryStorage.compute_hash(chunk.text)
memory_chunks.append(MemoryChunk(
id=chunk_id,
user_id=user_id,
scope=scope,
source=source,
path=path,
start_line=chunk.start_line,
end_line=chunk.end_line,
text=chunk.text,
embedding=embedding,
hash=chunk_hash,
metadata=metadata
))
# Save to storage
self.storage.save_chunks_batch(memory_chunks)
# Update file metadata
file_hash = MemoryStorage.compute_hash(content)
self.storage.update_file_metadata(
path=path,
source=source,
file_hash=file_hash,
mtime=int(os.path.getmtime(__file__)), # Use current time
size=len(content)
)
async def sync(self, force: bool = False):
"""
Synchronize memory from files
Args:
force: Force full reindex
"""
memory_dir = self.config.get_memory_dir()
workspace_dir = self.config.get_workspace()
# Scan MEMORY.md (workspace root)
memory_file = Path(workspace_dir) / "MEMORY.md"
if memory_file.exists():
await self._sync_file(memory_file, "memory", "shared", None)
# Scan memory directory (including daily summaries)
if memory_dir.exists():
for file_path in memory_dir.rglob("*.md"):
# Determine scope and user_id from path
rel_path = file_path.relative_to(workspace_dir)
parts = rel_path.parts
# Check if it's in daily summary directory
if "daily" in parts:
# Daily summary files
if "users" in parts or len(parts) > 3:
# User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md
user_idx = parts.index("daily") + 1
user_id = parts[user_idx] if user_idx < len(parts) else None
scope = "user"
else:
# Shared daily summary: memory/daily/2024-01-29.md
user_id = None
scope = "shared"
elif "users" in parts:
# User-scoped memory
user_idx = parts.index("users") + 1
user_id = parts[user_idx] if user_idx < len(parts) else None
scope = "user"
else:
# Shared memory
user_id = None
scope = "shared"
await self._sync_file(file_path, "memory", scope, user_id)
self._dirty = False
async def _sync_file(
self,
file_path: Path,
source: str,
scope: str,
user_id: Optional[str]
):
"""Sync a single file"""
# Compute file hash
content = file_path.read_text(encoding='utf-8')
file_hash = MemoryStorage.compute_hash(content)
# Get relative path
workspace_dir = self.config.get_workspace()
rel_path = str(file_path.relative_to(workspace_dir))
# Check if file changed
stored_hash = self.storage.get_file_hash(rel_path)
if stored_hash == file_hash:
return # No changes
# Delete old chunks
self.storage.delete_by_path(rel_path)
# Chunk and embed
chunks = self.chunker.chunk_text(content)
if not chunks:
return
texts = [chunk.text for chunk in chunks]
if self.embedding_provider:
embeddings = self.embedding_provider.embed_batch(texts)
else:
embeddings = [None] * len(texts)
# Create memory chunks
memory_chunks = []
for chunk, embedding in zip(chunks, embeddings):
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
chunk_hash = MemoryStorage.compute_hash(chunk.text)
memory_chunks.append(MemoryChunk(
id=chunk_id,
user_id=user_id,
scope=scope,
source=source,
path=rel_path,
start_line=chunk.start_line,
end_line=chunk.end_line,
text=chunk.text,
embedding=embedding,
hash=chunk_hash,
metadata=None
))
# Save
self.storage.save_chunks_batch(memory_chunks)
# Update file metadata
stat = file_path.stat()
self.storage.update_file_metadata(
path=rel_path,
source=source,
file_hash=file_hash,
mtime=int(stat.st_mtime),
size=stat.st_size
)
def flush_memory(
self,
messages: list,
user_id: Optional[str] = None,
reason: str = "threshold",
max_messages: int = 10,
) -> bool:
"""
Flush conversation summary to daily memory file.
Args:
messages: Conversation message list
user_id: Optional user ID
reason: "threshold" | "overflow" | "daily_summary"
max_messages: Max recent messages to include (0 = all)
Returns:
True if content was written
"""
success = self.flush_manager.flush_from_messages(
messages=messages,
user_id=user_id,
reason=reason,
max_messages=max_messages,
)
if success:
self._dirty = True
return success
def get_status(self) -> Dict[str, Any]:
"""Get memory status"""
stats = self.storage.get_stats()
return {
'chunks': stats['chunks'],
'files': stats['files'],
'workspace': str(self.config.get_workspace()),
'dirty': self._dirty,
'embedding_enabled': self.embedding_provider is not None,
'embedding_provider': self.config.embedding_provider if self.embedding_provider else 'disabled',
'embedding_model': self.config.embedding_model if self.embedding_provider else 'N/A',
'search_mode': 'hybrid (vector + keyword)' if self.embedding_provider else 'keyword only (FTS5)'
}
def mark_dirty(self):
"""Mark memory as dirty (needs sync)"""
self._dirty = True
def close(self):
"""Close memory manager and release resources"""
self.storage.close()
# Helper methods
def _generate_chunk_id(self, path: str, start_line: int, end_line: int) -> str:
"""Generate unique chunk ID"""
content = f"{path}:{start_line}:{end_line}"
return hashlib.md5(content.encode('utf-8')).hexdigest()
@staticmethod
def _compute_temporal_decay(path: str, half_life_days: float = 30.0) -> float:
"""
Compute temporal decay multiplier for dated memory files.
Inspired by OpenClaw's temporal-decay: exponential decay based on file date.
MEMORY.md and non-dated files are "evergreen" (no decay, multiplier=1.0).
Daily files like memory/2025-03-01.md decay based on age.
Formula: multiplier = exp(-ln2/half_life * age_in_days)
"""
import re
import math
match = re.search(r'(\d{4})-(\d{2})-(\d{2})\.md$', path)
if not match:
return 1.0 # evergreen: MEMORY.md, non-dated files
try:
file_date = datetime(
int(match.group(1)), int(match.group(2)), int(match.group(3))
)
age_days = (datetime.now() - file_date).days
if age_days <= 0:
return 1.0
decay_lambda = math.log(2) / half_life_days
return math.exp(-decay_lambda * age_days)
except (ValueError, OverflowError):
return 1.0
def _merge_results(
self,
vector_results: List[SearchResult],
keyword_results: List[SearchResult],
vector_weight: float,
keyword_weight: float
) -> List[SearchResult]:
"""Merge vector and keyword search results with temporal decay for dated files"""
merged_map = {}
for result in vector_results:
key = (result.path, result.start_line, result.end_line)
merged_map[key] = {
'result': result,
'vector_score': result.score,
'keyword_score': 0.0
}
for result in keyword_results:
key = (result.path, result.start_line, result.end_line)
if key in merged_map:
merged_map[key]['keyword_score'] = result.score
else:
merged_map[key] = {
'result': result,
'vector_score': 0.0,
'keyword_score': result.score
}
merged_results = []
for entry in merged_map.values():
combined_score = (
vector_weight * entry['vector_score'] +
keyword_weight * entry['keyword_score']
)
# Apply temporal decay for dated memory files
result = entry['result']
decay = self._compute_temporal_decay(result.path)
combined_score *= decay
merged_results.append(SearchResult(
path=result.path,
start_line=result.start_line,
end_line=result.end_line,
score=combined_score,
snippet=result.snippet,
source=result.source,
user_id=result.user_id
))
merged_results.sort(key=lambda r: r.score, reverse=True)
return merged_results

167
agent/memory/service.py Normal file
View File

@@ -0,0 +1,167 @@
"""
Memory service for handling memory query operations via cloud protocol.
Provides a unified interface for listing and reading memory files,
callable from the cloud client (LinkAI) or a future web console.
Memory file layout (under workspace_root):
MEMORY.md -> type: global
memory/2026-02-20.md -> type: daily
"""
import os
from datetime import datetime
from typing import Dict, List, Optional
from pathlib import Path
from common.log import logger
class MemoryService:
"""
High-level service for memory file queries.
Operates directly on the filesystem — no MemoryManager dependency.
"""
def __init__(self, workspace_root: str):
"""
:param workspace_root: Workspace root directory (e.g. ~/cow)
"""
self.workspace_root = workspace_root
self.memory_dir = os.path.join(workspace_root, "memory")
# ------------------------------------------------------------------
# list — paginated file metadata
# ------------------------------------------------------------------
def list_files(self, page: int = 1, page_size: int = 20) -> dict:
"""
List all memory files with metadata (without content).
Returns::
{
"page": 1,
"page_size": 20,
"total": 15,
"list": [
{"filename": "MEMORY.md", "type": "global", "size": 2048, "updated_at": "2026-02-20 10:00:00"},
{"filename": "2026-02-20.md", "type": "daily", "size": 512, "updated_at": "2026-02-20 09:30:00"},
...
]
}
"""
files: List[dict] = []
# 1. Global memory — MEMORY.md in workspace root
global_path = os.path.join(self.workspace_root, "MEMORY.md")
if os.path.isfile(global_path):
files.append(self._file_info(global_path, "MEMORY.md", "global"))
# 2. Daily memory files — memory/*.md (sorted newest first)
if os.path.isdir(self.memory_dir):
daily_files = []
for name in os.listdir(self.memory_dir):
full = os.path.join(self.memory_dir, name)
if os.path.isfile(full) and name.endswith(".md"):
daily_files.append((name, full))
# Sort by filename descending (newest date first)
daily_files.sort(key=lambda x: x[0], reverse=True)
for name, full in daily_files:
files.append(self._file_info(full, name, "daily"))
total = len(files)
# Paginate
start = (page - 1) * page_size
end = start + page_size
page_items = files[start:end]
return {
"page": page,
"page_size": page_size,
"total": total,
"list": page_items,
}
# ------------------------------------------------------------------
# content — read a single file
# ------------------------------------------------------------------
def get_content(self, filename: str) -> dict:
"""
Read the full content of a memory file.
:param filename: File name, e.g. ``MEMORY.md`` or ``2026-02-20.md``
:return: dict with ``filename`` and ``content``
:raises FileNotFoundError: if the file does not exist
"""
path = self._resolve_path(filename)
if not os.path.isfile(path):
raise FileNotFoundError(f"Memory file not found: {filename}")
with open(path, "r", encoding="utf-8") as f:
content = f.read()
return {
"filename": filename,
"content": content,
}
# ------------------------------------------------------------------
# dispatch — single entry point for protocol messages
# ------------------------------------------------------------------
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
"""
Dispatch a memory management action.
:param action: ``list`` or ``content``
:param payload: action-specific payload
:return: protocol-compatible response dict
"""
payload = payload or {}
try:
if action == "list":
page = payload.get("page", 1)
page_size = payload.get("page_size", 20)
result_payload = self.list_files(page=page, page_size=page_size)
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
elif action == "content":
filename = payload.get("filename")
if not filename:
return {"action": action, "code": 400, "message": "filename is required", "payload": None}
result_payload = self.get_content(filename)
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
else:
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
except FileNotFoundError as e:
return {"action": action, "code": 404, "message": str(e), "payload": None}
except Exception as e:
logger.error(f"[MemoryService] dispatch error: action={action}, error={e}")
return {"action": action, "code": 500, "message": str(e), "payload": None}
# ------------------------------------------------------------------
# internal helpers
# ------------------------------------------------------------------
def _resolve_path(self, filename: str) -> str:
"""
Resolve a filename to its absolute path.
- ``MEMORY.md`` → ``{workspace_root}/MEMORY.md``
- ``2026-02-20.md`` → ``{workspace_root}/memory/2026-02-20.md``
"""
if filename == "MEMORY.md":
return os.path.join(self.workspace_root, filename)
return os.path.join(self.memory_dir, filename)
@staticmethod
def _file_info(path: str, filename: str, file_type: str) -> dict:
"""Build a file metadata dict."""
stat = os.stat(path)
updated_at = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
return {
"filename": filename,
"type": file_type,
"size": stat.st_size,
"updated_at": updated_at,
}

589
agent/memory/storage.py Normal file
View File

@@ -0,0 +1,589 @@
"""
Storage layer for memory using SQLite + FTS5
Provides vector and keyword search capabilities
"""
from __future__ import annotations
import sqlite3
import json
import hashlib
from typing import List, Dict, Optional, Any
from pathlib import Path
from dataclasses import dataclass
@dataclass
class MemoryChunk:
"""Represents a memory chunk with text and embedding"""
id: str
user_id: Optional[str]
scope: str # "shared" | "user" | "session"
source: str # "memory" | "session"
path: str
start_line: int
end_line: int
text: str
embedding: Optional[List[float]]
hash: str
metadata: Optional[Dict[str, Any]] = None
@dataclass
class SearchResult:
"""Search result with score and snippet"""
path: str
start_line: int
end_line: int
score: float
snippet: str
source: str
user_id: Optional[str] = None
class MemoryStorage:
"""SQLite-based storage with FTS5 for keyword search"""
def __init__(self, db_path: Path):
self.db_path = db_path
self.conn: Optional[sqlite3.Connection] = None
self.fts5_available = False # Track FTS5 availability
self._init_db()
def _check_fts5_support(self) -> bool:
"""Check if SQLite has FTS5 support"""
try:
self.conn.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(test)")
self.conn.execute("DROP TABLE IF EXISTS fts5_test")
return True
except sqlite3.OperationalError as e:
if "no such module: fts5" in str(e):
return False
raise
def _init_db(self):
"""Initialize database with schema"""
try:
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
# Check FTS5 support
self.fts5_available = self._check_fts5_support()
if not self.fts5_available:
from common.log import logger
logger.debug("[MemoryStorage] FTS5 not available, using LIKE-based keyword search")
# Check database integrity
try:
result = self.conn.execute("PRAGMA integrity_check").fetchone()
if result[0] != 'ok':
print(f"⚠️ Database integrity check failed: {result[0]}")
print(f" Recreating database...")
self.conn.close()
self.conn = None
# Remove corrupted database
self.db_path.unlink(missing_ok=True)
# Remove WAL files
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
# Reconnect to create new database
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
except sqlite3.DatabaseError:
# Database is corrupted, recreate it
print(f"⚠️ Database is corrupted, recreating...")
if self.conn:
self.conn.close()
self.conn = None
self.db_path.unlink(missing_ok=True)
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
# Enable WAL mode for better concurrency
self.conn.execute("PRAGMA journal_mode=WAL")
# Set busy timeout to avoid "database is locked" errors
self.conn.execute("PRAGMA busy_timeout=5000")
except Exception as e:
print(f"⚠️ Unexpected error during database initialization: {e}")
raise
# Create chunks table with embeddings
self.conn.execute("""
CREATE TABLE IF NOT EXISTS chunks (
id TEXT PRIMARY KEY,
user_id TEXT,
scope TEXT NOT NULL DEFAULT 'shared',
source TEXT NOT NULL DEFAULT 'memory',
path TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL,
text TEXT NOT NULL,
embedding TEXT,
hash TEXT NOT NULL,
metadata TEXT,
created_at INTEGER DEFAULT (strftime('%s', 'now')),
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
)
""")
# Create indexes
self.conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_user
ON chunks(user_id)
""")
self.conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_scope
ON chunks(scope)
""")
self.conn.execute("""
CREATE INDEX IF NOT EXISTS idx_chunks_hash
ON chunks(path, hash)
""")
# Create FTS5 virtual table for keyword search (only if supported)
if self.fts5_available:
# Use default unicode61 tokenizer (stable and compatible)
# For CJK support, we'll use LIKE queries as fallback
self.conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
text,
id UNINDEXED,
user_id UNINDEXED,
path UNINDEXED,
source UNINDEXED,
scope UNINDEXED,
content='chunks',
content_rowid='rowid'
)
""")
# Create triggers to keep FTS in sync
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
END
""")
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
DELETE FROM chunks_fts WHERE rowid = old.rowid;
END
""")
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
UPDATE chunks_fts SET text = new.text, id = new.id,
user_id = new.user_id, path = new.path, source = new.source, scope = new.scope
WHERE rowid = new.rowid;
END
""")
# Create files metadata table
self.conn.execute("""
CREATE TABLE IF NOT EXISTS files (
path TEXT PRIMARY KEY,
source TEXT NOT NULL DEFAULT 'memory',
hash TEXT NOT NULL,
mtime INTEGER NOT NULL,
size INTEGER NOT NULL,
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
)
""")
self.conn.commit()
def save_chunk(self, chunk: MemoryChunk):
"""Save a memory chunk"""
self.conn.execute("""
INSERT OR REPLACE INTO chunks
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
""", (
chunk.id,
chunk.user_id,
chunk.scope,
chunk.source,
chunk.path,
chunk.start_line,
chunk.end_line,
chunk.text,
json.dumps(chunk.embedding) if chunk.embedding else None,
chunk.hash,
json.dumps(chunk.metadata) if chunk.metadata else None
))
self.conn.commit()
def save_chunks_batch(self, chunks: List[MemoryChunk]):
"""Save multiple chunks in a batch"""
self.conn.executemany("""
INSERT OR REPLACE INTO chunks
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
""", [
(
c.id, c.user_id, c.scope, c.source, c.path,
c.start_line, c.end_line, c.text,
json.dumps(c.embedding) if c.embedding else None,
c.hash,
json.dumps(c.metadata) if c.metadata else None
)
for c in chunks
])
self.conn.commit()
def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
"""Get a chunk by ID"""
row = self.conn.execute("""
SELECT * FROM chunks WHERE id = ?
""", (chunk_id,)).fetchone()
if not row:
return None
return self._row_to_chunk(row)
def search_vector(
self,
query_embedding: List[float],
user_id: Optional[str] = None,
scopes: List[str] = None,
limit: int = 10
) -> List[SearchResult]:
"""
Vector similarity search using in-memory cosine similarity
(sqlite-vec can be added later for better performance)
"""
if scopes is None:
scopes = ["shared"]
if user_id:
scopes.append("user")
# Build query
scope_placeholders = ','.join('?' * len(scopes))
params = scopes
if user_id:
query = f"""
SELECT * FROM chunks
WHERE scope IN ({scope_placeholders})
AND (scope = 'shared' OR user_id = ?)
AND embedding IS NOT NULL
"""
params.append(user_id)
else:
query = f"""
SELECT * FROM chunks
WHERE scope IN ({scope_placeholders})
AND embedding IS NOT NULL
"""
rows = self.conn.execute(query, params).fetchall()
# Calculate cosine similarity
results = []
for row in rows:
embedding = json.loads(row['embedding'])
similarity = self._cosine_similarity(query_embedding, embedding)
if similarity > 0:
results.append((similarity, row))
# Sort by similarity and limit
results.sort(key=lambda x: x[0], reverse=True)
results = results[:limit]
return [
SearchResult(
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
score=score,
snippet=self._truncate_text(row['text'], 500),
source=row['source'],
user_id=row['user_id']
)
for score, row in results
]
def search_keyword(
self,
query: str,
user_id: Optional[str] = None,
scopes: List[str] = None,
limit: int = 10
) -> List[SearchResult]:
"""
Keyword search using FTS5 + LIKE fallback
Strategy:
1. If FTS5 available: Try FTS5 search first (good for English and word-based languages)
2. If no FTS5 or no results and query contains CJK: Use LIKE search
"""
if scopes is None:
scopes = ["shared"]
if user_id:
scopes.append("user")
# Try FTS5 search first (if available)
if self.fts5_available:
fts_results = self._search_fts5(query, user_id, scopes, limit)
if fts_results:
return fts_results
# Fallback to LIKE search (always for CJK, or if FTS5 not available)
if not self.fts5_available or MemoryStorage._contains_cjk(query):
return self._search_like(query, user_id, scopes, limit)
return []
def _search_fts5(
self,
query: str,
user_id: Optional[str],
scopes: List[str],
limit: int
) -> List[SearchResult]:
"""FTS5 full-text search"""
fts_query = self._build_fts_query(query)
if not fts_query:
return []
scope_placeholders = ','.join('?' * len(scopes))
params = [fts_query] + scopes
if user_id:
sql_query = f"""
SELECT chunks.*, bm25(chunks_fts) as rank
FROM chunks_fts
JOIN chunks ON chunks.id = chunks_fts.id
WHERE chunks_fts MATCH ?
AND chunks.scope IN ({scope_placeholders})
AND (chunks.scope = 'shared' OR chunks.user_id = ?)
ORDER BY rank
LIMIT ?
"""
params.extend([user_id, limit])
else:
sql_query = f"""
SELECT chunks.*, bm25(chunks_fts) as rank
FROM chunks_fts
JOIN chunks ON chunks.id = chunks_fts.id
WHERE chunks_fts MATCH ?
AND chunks.scope IN ({scope_placeholders})
ORDER BY rank
LIMIT ?
"""
params.append(limit)
try:
rows = self.conn.execute(sql_query, params).fetchall()
return [
SearchResult(
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
score=self._bm25_rank_to_score(row['rank']),
snippet=self._truncate_text(row['text'], 500),
source=row['source'],
user_id=row['user_id']
)
for row in rows
]
except Exception:
return []
def _search_like(
self,
query: str,
user_id: Optional[str],
scopes: List[str],
limit: int
) -> List[SearchResult]:
"""LIKE-based search for CJK characters"""
import re
# Extract CJK words (2+ characters)
cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query)
if not cjk_words:
return []
scope_placeholders = ','.join('?' * len(scopes))
# Build LIKE conditions for each word
like_conditions = []
params = []
for word in cjk_words:
like_conditions.append("text LIKE ?")
params.append(f'%{word}%')
where_clause = ' OR '.join(like_conditions)
params.extend(scopes)
if user_id:
sql_query = f"""
SELECT * FROM chunks
WHERE ({where_clause})
AND scope IN ({scope_placeholders})
AND (scope = 'shared' OR user_id = ?)
LIMIT ?
"""
params.extend([user_id, limit])
else:
sql_query = f"""
SELECT * FROM chunks
WHERE ({where_clause})
AND scope IN ({scope_placeholders})
LIMIT ?
"""
params.append(limit)
try:
rows = self.conn.execute(sql_query, params).fetchall()
return [
SearchResult(
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
score=0.5, # Fixed score for LIKE search
snippet=self._truncate_text(row['text'], 500),
source=row['source'],
user_id=row['user_id']
)
for row in rows
]
except Exception:
return []
def delete_by_path(self, path: str):
"""Delete all chunks from a file"""
self.conn.execute("""
DELETE FROM chunks WHERE path = ?
""", (path,))
self.conn.commit()
def get_file_hash(self, path: str) -> Optional[str]:
"""Get stored file hash"""
row = self.conn.execute("""
SELECT hash FROM files WHERE path = ?
""", (path,)).fetchone()
return row['hash'] if row else None
def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int):
"""Update file metadata"""
self.conn.execute("""
INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
""", (path, source, file_hash, mtime, size))
self.conn.commit()
def get_stats(self) -> Dict[str, int]:
"""Get storage statistics"""
chunks_count = self.conn.execute("""
SELECT COUNT(*) as cnt FROM chunks
""").fetchone()['cnt']
files_count = self.conn.execute("""
SELECT COUNT(*) as cnt FROM files
""").fetchone()['cnt']
return {
'chunks': chunks_count,
'files': files_count
}
def close(self):
"""Close database connection"""
if self.conn:
try:
self.conn.commit() # Ensure all changes are committed
self.conn.close()
self.conn = None # Mark as closed
except Exception as e:
print(f"⚠️ Error closing database connection: {e}")
def __del__(self):
"""Destructor to ensure connection is closed"""
try:
self.close()
except Exception:
pass # Ignore errors during cleanup
# Helper methods
def _row_to_chunk(self, row) -> MemoryChunk:
"""Convert database row to MemoryChunk"""
return MemoryChunk(
id=row['id'],
user_id=row['user_id'],
scope=row['scope'],
source=row['source'],
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
text=row['text'],
embedding=json.loads(row['embedding']) if row['embedding'] else None,
hash=row['hash'],
metadata=json.loads(row['metadata']) if row['metadata'] else None
)
@staticmethod
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
"""Calculate cosine similarity between two vectors"""
if len(vec1) != len(vec2):
return 0.0
dot_product = sum(a * b for a, b in zip(vec1, vec2))
norm1 = sum(a * a for a in vec1) ** 0.5
norm2 = sum(b * b for b in vec2) ** 0.5
if norm1 == 0 or norm2 == 0:
return 0.0
return dot_product / (norm1 * norm2)
@staticmethod
def _contains_cjk(text: str) -> bool:
"""Check if text contains CJK (Chinese/Japanese/Korean) characters"""
import re
return bool(re.search(r'[\u4e00-\u9fff]', text))
@staticmethod
def _build_fts_query(raw_query: str) -> Optional[str]:
"""
Build FTS5 query from raw text
Works best for English and word-based languages.
For CJK characters, LIKE search will be used as fallback.
"""
import re
# Extract words (primarily English words and numbers)
tokens = re.findall(r'[A-Za-z0-9_]+', raw_query)
if not tokens:
return None
# Quote tokens for exact matching
quoted = [f'"{t}"' for t in tokens]
# Use OR for more flexible matching
return ' OR '.join(quoted)
@staticmethod
def _bm25_rank_to_score(rank: float) -> float:
"""Convert BM25 rank to 0-1 score"""
normalized = max(0, rank) if rank is not None else 999
return 1 / (1 + normalized)
@staticmethod
def _truncate_text(text: str, max_chars: int) -> str:
"""Truncate text to max characters"""
if len(text) <= max_chars:
return text
return text[:max_chars] + "..."
@staticmethod
def compute_hash(content: str) -> str:
"""Compute SHA256 hash of content"""
return hashlib.sha256(content.encode('utf-8')).hexdigest()

370
agent/memory/summarizer.py Normal file
View File

@@ -0,0 +1,370 @@
"""
Memory flush manager
Handles memory persistence when conversation context is trimmed or overflows:
- Uses LLM to summarize discarded messages into concise key-information entries
- Writes to daily memory files (lazy creation)
- Deduplicates trim flushes to avoid repeated writes
- Runs summarization asynchronously to avoid blocking normal replies
- Provides daily summary interface for scheduler
"""
import threading
from typing import Optional, Callable, Any, List, Dict
from pathlib import Path
from datetime import datetime
from common.log import logger
SUMMARIZE_SYSTEM_PROMPT = """你是一个记忆提取助手。你的任务是从对话记录中提取值得记住的信息,生成简洁的记忆摘要。
输出要求:
1. 以事件/关键信息为维度记录,每条一行,用 "- " 开头
2. 记录有价值的关键信息,例如用户提出的要求及助手的解决方案,对话中涉及的事实信息,用户的偏好、决策或重要结论
3. 每条摘要需要简明扼要,只保留关键信息
4. 直接输出摘要内容,不要加任何前缀说明
5. 当对话没有任何记录价值例如只是简单问候,可回复"\""""
SUMMARIZE_USER_PROMPT = """请从以下对话记录中提取关键信息,生成记忆摘要:
{conversation}"""
class MemoryFlushManager:
"""
Manages memory flush operations.
Flush is triggered by agent_stream in two scenarios:
1. Context trim: _trim_messages discards old turns → flush discarded content
2. Context overflow: API rejects request → emergency flush before clearing
Additionally, create_daily_summary() can be called by scheduler for end-of-day summaries.
"""
def __init__(
self,
workspace_dir: Path,
llm_model: Optional[Any] = None,
):
self.workspace_dir = workspace_dir
self.llm_model = llm_model
self.memory_dir = workspace_dir / "memory"
self.memory_dir.mkdir(parents=True, exist_ok=True)
self.last_flush_timestamp: Optional[datetime] = None
self._trim_flushed_hashes: set = set() # Content hashes of already-flushed messages
self._last_flushed_content_hash: str = "" # Content hash at last flush, for daily dedup
def get_today_memory_file(self, user_id: Optional[str] = None, ensure_exists: bool = False) -> Path:
"""Get today's memory file path: memory/YYYY-MM-DD.md"""
today = datetime.now().strftime("%Y-%m-%d")
if user_id:
user_dir = self.memory_dir / "users" / user_id
if ensure_exists:
user_dir.mkdir(parents=True, exist_ok=True)
today_file = user_dir / f"{today}.md"
else:
today_file = self.memory_dir / f"{today}.md"
if ensure_exists and not today_file.exists():
today_file.parent.mkdir(parents=True, exist_ok=True)
today_file.write_text(f"# Daily Memory: {today}\n\n")
return today_file
def get_main_memory_file(self, user_id: Optional[str] = None) -> Path:
"""Get main memory file path: MEMORY.md (workspace root)"""
if user_id:
user_dir = self.memory_dir / "users" / user_id
user_dir.mkdir(parents=True, exist_ok=True)
return user_dir / "MEMORY.md"
else:
return Path(self.workspace_dir) / "MEMORY.md"
def get_status(self) -> dict:
return {
'last_flush_time': self.last_flush_timestamp.isoformat() if self.last_flush_timestamp else None,
'today_file': str(self.get_today_memory_file()),
'main_file': str(self.get_main_memory_file())
}
# ---- Flush execution (called by agent_stream or scheduler) ----
def flush_from_messages(
self,
messages: List[Dict],
user_id: Optional[str] = None,
reason: str = "trim",
max_messages: int = 0,
) -> bool:
"""
Asynchronously summarize and flush messages to daily memory.
Deduplication runs synchronously, then LLM summarization + file write
run in a background thread so the main reply flow is never blocked.
Args:
messages: Conversation message list (OpenAI/Claude format)
user_id: Optional user ID for user-scoped memory
reason: Why flush was triggered ("trim" | "overflow" | "daily_summary")
max_messages: Max recent messages to summarize (0 = all)
Returns:
True if flush was dispatched
"""
try:
import hashlib
deduped = []
for m in messages:
text = self._extract_text_from_content(m.get("content", ""))
if not text or not text.strip():
continue
h = hashlib.md5(text.encode("utf-8")).hexdigest()
if h not in self._trim_flushed_hashes:
self._trim_flushed_hashes.add(h)
deduped.append(m)
if not deduped:
return False
import copy
snapshot = copy.deepcopy(deduped)
thread = threading.Thread(
target=self._flush_worker,
args=(snapshot, user_id, reason, max_messages),
daemon=True,
)
thread.start()
logger.info(f"[MemoryFlush] Async flush dispatched (reason={reason}, msgs={len(snapshot)})")
return True
except Exception as e:
logger.warning(f"[MemoryFlush] Failed to dispatch flush (reason={reason}): {e}")
return False
def _flush_worker(
self,
messages: List[Dict],
user_id: Optional[str],
reason: str,
max_messages: int,
):
"""Background worker: summarize with LLM and write to daily file."""
try:
summary = self._summarize_messages(messages, max_messages)
if not summary or not summary.strip() or summary.strip() == "":
logger.info(f"[MemoryFlush] No valuable content to flush (reason={reason})")
return
daily_file = ensure_daily_memory_file(self.workspace_dir, user_id)
if reason == "overflow":
header = f"## Context Overflow Recovery ({datetime.now().strftime('%H:%M')})"
note = "The following conversation was trimmed due to context overflow:\n"
elif reason == "trim":
header = f"## Trimmed Context ({datetime.now().strftime('%H:%M')})"
note = ""
elif reason == "daily_summary":
header = f"## Daily Summary ({datetime.now().strftime('%H:%M')})"
note = ""
else:
header = f"## Session Notes ({datetime.now().strftime('%H:%M')})"
note = ""
flush_entry = f"\n{header}\n\n{note}{summary}\n"
with open(daily_file, "a", encoding="utf-8") as f:
f.write(flush_entry)
self.last_flush_timestamp = datetime.now()
logger.info(f"[MemoryFlush] Wrote to {daily_file.name} (reason={reason}, chars={len(summary)})")
except Exception as e:
logger.warning(f"[MemoryFlush] Async flush failed (reason={reason}): {e}")
def create_daily_summary(
self,
messages: List[Dict],
user_id: Optional[str] = None
) -> bool:
"""
Generate end-of-day summary. Called by daily timer.
Skips if messages haven't changed since last flush.
"""
import hashlib
content = "".join(
self._extract_text_from_content(m.get("content", ""))
for m in messages
)
content_hash = hashlib.md5(content.encode("utf-8")).hexdigest()
if content_hash == self._last_flushed_content_hash:
logger.debug("[MemoryFlush] Daily summary skipped: no new content since last flush")
return False
self._last_flushed_content_hash = content_hash
return self.flush_from_messages(
messages=messages,
user_id=user_id,
reason="daily_summary",
max_messages=0,
)
# ---- Internal helpers ----
def _summarize_messages(self, messages: List[Dict], max_messages: int = 0) -> str:
"""
Summarize conversation messages using LLM, with rule-based fallback.
"""
conversation_text = self._format_conversation_for_summary(messages, max_messages)
if not conversation_text.strip():
return ""
# Try LLM summarization first
if self.llm_model:
try:
summary = self._call_llm_for_summary(conversation_text)
if summary and summary.strip() and summary.strip() != "":
return summary.strip()
except Exception as e:
logger.warning(f"[MemoryFlush] LLM summarization failed, using fallback: {e}")
return self._extract_summary_fallback(messages, max_messages)
def _format_conversation_for_summary(self, messages: List[Dict], max_messages: int = 0) -> str:
"""Format messages into readable conversation text for LLM summarization."""
msgs = messages if max_messages == 0 else messages[-max_messages * 2:]
lines = []
for msg in msgs:
role = msg.get("role", "")
text = self._extract_text_from_content(msg.get("content", ""))
if not text or not text.strip():
continue
text = text.strip()
if role == "user":
lines.append(f"用户: {text[:500]}")
elif role == "assistant":
lines.append(f"助手: {text[:500]}")
return "\n".join(lines)
def _call_llm_for_summary(self, conversation_text: str) -> str:
"""Call LLM to generate a concise summary of the conversation."""
from agent.protocol.models import LLMRequest
request = LLMRequest(
messages=[{"role": "user", "content": SUMMARIZE_USER_PROMPT.format(conversation=conversation_text)}],
temperature=0,
max_tokens=500,
stream=False,
system=SUMMARIZE_SYSTEM_PROMPT,
)
response = self.llm_model.call(request)
if isinstance(response, dict):
if response.get("error"):
raise RuntimeError(response.get("message", "LLM call failed"))
# OpenAI format
choices = response.get("choices", [])
if choices:
return choices[0].get("message", {}).get("content", "")
# Handle response object with attribute access (e.g. OpenAI SDK response)
if hasattr(response, "choices") and response.choices:
return response.choices[0].message.content or ""
return ""
@staticmethod
def _extract_summary_fallback(messages: List[Dict], max_messages: int = 0) -> str:
"""Rule-based fallback when LLM is unavailable."""
msgs = messages if max_messages == 0 else messages[-max_messages * 2:]
items = []
for msg in msgs:
role = msg.get("role", "")
text = MemoryFlushManager._extract_text_from_content(msg.get("content", ""))
if not text or not text.strip():
continue
text = text.strip()
if role == "user":
if len(text) <= 5:
continue
items.append(f"- 用户请求: {text[:200]}")
elif role == "assistant":
first_line = text.split("\n")[0].strip()
if len(first_line) > 10:
items.append(f"- 处理结果: {first_line[:200]}")
return "\n".join(items[:15])
@staticmethod
def _extract_text_from_content(content) -> str:
"""Extract plain text from message content (string or content blocks)."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
parts.append(block.get("text", ""))
elif isinstance(block, str):
parts.append(block)
return "\n".join(parts)
return ""
def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None):
"""
Create essential memory files if they don't exist.
Only creates MEMORY.md; daily files are created lazily on first write.
Args:
workspace_dir: Workspace directory
user_id: Optional user ID for user-specific files
"""
memory_dir = workspace_dir / "memory"
memory_dir.mkdir(parents=True, exist_ok=True)
# Create main MEMORY.md in workspace root (always needed for bootstrap)
if user_id:
user_dir = memory_dir / "users" / user_id
user_dir.mkdir(parents=True, exist_ok=True)
main_memory = user_dir / "MEMORY.md"
else:
main_memory = Path(workspace_dir) / "MEMORY.md"
if not main_memory.exists():
main_memory.write_text("")
def ensure_daily_memory_file(workspace_dir: Path, user_id: Optional[str] = None) -> Path:
"""
Ensure today's daily memory file exists, creating it only when actually needed.
Called lazily before first write to daily memory.
Args:
workspace_dir: Workspace directory
user_id: Optional user ID for user-specific files
Returns:
Path to today's memory file
"""
memory_dir = workspace_dir / "memory"
memory_dir.mkdir(parents=True, exist_ok=True)
today = datetime.now().strftime("%Y-%m-%d")
if user_id:
user_dir = memory_dir / "users" / user_id
user_dir.mkdir(parents=True, exist_ok=True)
today_memory = user_dir / f"{today}.md"
else:
today_memory = memory_dir / f"{today}.md"
if not today_memory.exists():
today_memory.write_text(
f"# Daily Memory: {today}\n\n"
)
return today_memory

13
agent/prompt/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""
Agent Prompt Module - 系统提示词构建模块
"""
from .builder import PromptBuilder, build_agent_system_prompt
from .workspace import ensure_workspace, load_context_files
__all__ = [
'PromptBuilder',
'build_agent_system_prompt',
'ensure_workspace',
'load_context_files',
]

492
agent/prompt/builder.py Normal file
View File

@@ -0,0 +1,492 @@
"""
System Prompt Builder - 系统提示词构建器
实现模块化的系统提示词构建,支持工具、技能、记忆等多个子系统
"""
from __future__ import annotations
import os
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
from common.log import logger
@dataclass
class ContextFile:
"""上下文文件"""
path: str
content: str
class PromptBuilder:
"""提示词构建器"""
def __init__(self, workspace_dir: str, language: str = "zh"):
"""
初始化提示词构建器
Args:
workspace_dir: 工作空间目录
language: 语言 ("zh""en")
"""
self.workspace_dir = workspace_dir
self.language = language
def build(
self,
base_persona: Optional[str] = None,
user_identity: Optional[Dict[str, str]] = None,
tools: Optional[List[Any]] = None,
context_files: Optional[List[ContextFile]] = None,
skill_manager: Any = None,
memory_manager: Any = None,
runtime_info: Optional[Dict[str, Any]] = None,
**kwargs
) -> str:
"""
构建完整的系统提示词
Args:
base_persona: 基础人格描述会被context_files中的AGENT.md覆盖
user_identity: 用户身份信息
tools: 工具列表
context_files: 上下文文件列表AGENT.md, USER.md, RULE.md, BOOTSTRAP.md等
skill_manager: 技能管理器
memory_manager: 记忆管理器
runtime_info: 运行时信息
**kwargs: 其他参数
Returns:
完整的系统提示词
"""
return build_agent_system_prompt(
workspace_dir=self.workspace_dir,
language=self.language,
base_persona=base_persona,
user_identity=user_identity,
tools=tools,
context_files=context_files,
skill_manager=skill_manager,
memory_manager=memory_manager,
runtime_info=runtime_info,
**kwargs
)
def build_agent_system_prompt(
workspace_dir: str,
language: str = "zh",
base_persona: Optional[str] = None,
user_identity: Optional[Dict[str, str]] = None,
tools: Optional[List[Any]] = None,
context_files: Optional[List[ContextFile]] = None,
skill_manager: Any = None,
memory_manager: Any = None,
runtime_info: Optional[Dict[str, Any]] = None,
**kwargs
) -> str:
"""
构建Agent系统提示词
顺序说明(按重要性和逻辑关系排列):
1. 工具系统 - 核心能力,最先介绍
2. 技能系统 - 紧跟工具,因为技能需要用 read 工具读取
3. 记忆系统 - 独立的记忆能力
4. 工作空间 - 工作环境说明
5. 用户身份 - 用户信息(可选)
6. 项目上下文 - AGENT.md, USER.md, RULE.md, BOOTSTRAP.md定义人格、身份、规则、初始化引导
7. 运行时信息 - 元信息(时间、模型等)
Args:
workspace_dir: 工作空间目录
language: 语言 ("zh""en")
base_persona: 基础人格描述已废弃由AGENT.md定义
user_identity: 用户身份信息
tools: 工具列表
context_files: 上下文文件列表
skill_manager: 技能管理器
memory_manager: 记忆管理器
runtime_info: 运行时信息
**kwargs: 其他参数
Returns:
完整的系统提示词
"""
sections = []
# 1. 工具系统(最重要,放在最前面)
if tools:
sections.extend(_build_tooling_section(tools, language))
# 2. 技能系统(紧跟工具,因为需要用 read 工具)
if skill_manager:
sections.extend(_build_skills_section(skill_manager, tools, language))
# 3. 记忆系统(独立的记忆能力)
if memory_manager:
sections.extend(_build_memory_section(memory_manager, tools, language))
# 4. 工作空间(工作环境说明)
sections.extend(_build_workspace_section(workspace_dir, language))
# 5. 用户身份(如果有)
if user_identity:
sections.extend(_build_user_identity_section(user_identity, language))
# 6. 项目上下文文件AGENT.md, USER.md, RULE.md - 定义人格)
if context_files:
sections.extend(_build_context_files_section(context_files, language))
# 7. 运行时信息(元信息,放在最后)
if runtime_info:
sections.extend(_build_runtime_section(runtime_info, language))
return "\n".join(sections)
def _build_identity_section(base_persona: Optional[str], language: str) -> List[str]:
"""构建基础身份section - 不再需要身份由AGENT.md定义"""
# 不再生成基础身份section完全由AGENT.md定义
return []
def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
"""Build tooling section with concise tool list and call style guide."""
# One-line summaries for known tools (details are in the tool schema)
core_summaries = {
"read": "读取文件内容",
"write": "创建或覆盖文件",
"edit": "精确编辑文件",
"ls": "列出目录内容",
"grep": "搜索文件内容",
"find": "按模式查找文件",
"bash": "执行shell命令",
"terminal": "管理后台进程",
"web_search": "网络搜索",
"web_fetch": "获取URL内容",
"browser": "控制浏览器(关键结果或需要协助可截图发送给用户)",
"memory_search": "搜索记忆",
"memory_get": "读取记忆内容",
"env_config": "管理API密钥和技能配置",
"scheduler": "管理定时任务和提醒",
"send": "发送本地文件给用户仅限本地文件URL直接放在回复文本中",
"vision": "分析图片内容识别、描述、OCR文字提取等",
}
# Preferred display order
tool_order = [
"read", "write", "edit", "ls", "grep", "find",
"bash", "terminal",
"web_search", "web_fetch", "browser",
"memory_search", "memory_get",
"env_config", "scheduler", "send", "vision",
]
# Build name -> summary mapping for available tools
available = {}
for tool in tools:
name = tool.name if hasattr(tool, 'name') else str(tool)
available[name] = core_summaries.get(name, "")
# Generate tool lines: ordered tools first, then extras
tool_lines = []
for name in tool_order:
if name in available:
summary = available.pop(name)
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
for name in sorted(available):
summary = available[name]
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
lines = [
"## 🔧 工具系统",
"",
"可用工具(名称大小写敏感,严格按列表调用):",
"\n".join(tool_lines),
"",
"工具调用风格:",
"",
"- 在多步骤任务、敏感操作或用户要求时简要解释决策过程",
"- 持续推进直到任务完成,完成后向用户报告结果。",
"- 回复中涉及密钥、令牌等敏感信息必须脱敏。",
"- URL链接直接放在回复文本中即可系统会自动处理和渲染。无需下载后使用send工具发送",
"",
]
return lines
def _build_skills_section(skill_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
"""构建技能系统section"""
if not skill_manager:
return []
# 获取read工具名称
read_tool_name = "read"
if tools:
for tool in tools:
tool_name = tool.name if hasattr(tool, 'name') else str(tool)
if tool_name.lower() == "read":
read_tool_name = tool_name
break
lines = [
"## 🧩 技能系统mandatory",
"",
"在回复之前:扫描下方 <available_skills> 中每个技能的 <description>。",
"",
f"- 如果有技能的描述与用户需求匹配:使用 `{read_tool_name}` 工具读取其 <location> 路径的 SKILL.md 文件,然后严格遵循文件中的指令。"
"当有匹配的技能时,应优先使用技能",
"- 如果多个技能都适用则选择最匹配的一个,然后读取并遵循。",
"- 如果没有技能明确适用:不要读取任何 SKILL.md直接使用通用工具。",
"",
f"**重要**: 技能不是工具,不能直接调用。使用技能的唯一方式是用 `{read_tool_name}` 读取 SKILL.md 文件,然后按文件内容操作。"
"永远不要一次性读取多个技能,只在选择后再读取。",
"",
"以下是可用技能:"
]
# 添加技能列表通过skill_manager获取
try:
skills_prompt = skill_manager.build_skills_prompt()
logger.debug(f"[PromptBuilder] Skills prompt length: {len(skills_prompt) if skills_prompt else 0}")
if skills_prompt:
lines.append(skills_prompt.strip())
lines.append("")
else:
logger.warning("[PromptBuilder] No skills prompt generated - skills_prompt is empty")
except Exception as e:
logger.warning(f"Failed to build skills prompt: {e}")
import traceback
logger.debug(f"Skills prompt error traceback: {traceback.format_exc()}")
return lines
def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
"""构建记忆系统section"""
if not memory_manager:
return []
# 检查是否有memory工具
has_memory_tools = False
if tools:
tool_names = [tool.name if hasattr(tool, 'name') else str(tool) for tool in tools]
has_memory_tools = any(name in ['memory_search', 'memory_get'] for name in tool_names)
if not has_memory_tools:
return []
from datetime import datetime
today_file = datetime.now().strftime("%Y-%m-%d") + ".md"
lines = [
"## 🧠 记忆系统",
"",
"### 检索记忆",
"",
"在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
"",
"1. 不确定记忆文件位置 → 先用 `memory_search` 通过关键词和语义检索相关内容",
"2. 已知文件位置 → 直接用 `memory_get` 读取相应的行 (例如MEMORY.md, memory/YYYY-MM-DD.md)",
"3. search 无结果 → 尝试用 `memory_get` 读取MEMORY.md及最近两天记忆文件",
"",
"**记忆文件结构**:",
f"- `MEMORY.md`: 长期记忆(核心信息、偏好、决策等)",
f"- `memory/YYYY-MM-DD.md`: 每日记忆,今天是 `memory/{today_file}`",
"",
"### 写入记忆",
"",
"**主动存储**:遇到以下情况时,应主动将信息写入记忆文件(无需告知用户):",
"",
"- 用户明确要求你记住某些信息",
"- 用户分享了重要的个人偏好、习惯、决策",
"- 对话中产生了重要的结论、方案、约定",
"- 完成了复杂任务,值得记录关键步骤和结果",
"- 发现了用户经常遇到的问题或解决方案",
"",
"**存储规则**:",
f"- 长期有效的核心信息 → `MEMORY.md`(文件保持精简,< 2000 tokens",
f"- 当天的事件、进展、笔记 → `memory/{today_file}`",
"- 追加内容 → `edit` 工具oldText 留空",
"- 修改内容 → `edit` 工具oldText 填写要替换的文本",
"- **禁止写入敏感信息**API密钥、令牌等敏感信息严禁写入记忆文件",
"",
"**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。",
"",
]
return lines
def _build_user_identity_section(user_identity: Dict[str, str], language: str) -> List[str]:
"""构建用户身份section"""
if not user_identity:
return []
lines = [
"## 👤 用户身份",
"",
]
if user_identity.get("name"):
lines.append(f"**用户姓名**: {user_identity['name']}")
if user_identity.get("nickname"):
lines.append(f"**称呼**: {user_identity['nickname']}")
if user_identity.get("timezone"):
lines.append(f"**时区**: {user_identity['timezone']}")
if user_identity.get("notes"):
lines.append(f"**备注**: {user_identity['notes']}")
lines.append("")
return lines
def _build_docs_section(workspace_dir: str, language: str) -> List[str]:
"""构建文档路径section - 已移除,不再需要"""
# 不再生成文档section
return []
def _build_workspace_section(workspace_dir: str, language: str) -> List[str]:
"""构建工作空间section"""
lines = [
"## 📂 工作空间",
"",
f"你的工作目录是: `{workspace_dir}`",
"",
"**路径使用规则** (非常重要):",
"",
f"1. **相对路径的基准目录**: 所有相对路径都是相对于 `{workspace_dir}` 而言的",
f" - ✅ 正确: 访问工作空间内的文件用相对路径,如 `AGENT.md`",
f" - ❌ 错误: 用相对路径访问其他目录的文件 (如果它不在 `{workspace_dir}` 内)",
"",
"2. **访问其他目录**: 如果要访问工作空间之外的目录(如项目代码、系统文件),**必须使用绝对路径**",
f" - ✅ 正确: 例如 `~/chatgpt-on-wechat`、`/usr/local/`",
f" - ❌ 错误: 假设相对路径会指向其他目录",
"",
"3. **路径解析示例**:",
f" - 相对路径 `memory/` → 实际路径 `{workspace_dir}/memory/`",
f" - 绝对路径 `~/chatgpt-on-wechat/docs/` → 实际路径 `~/chatgpt-on-wechat/docs/`",
"",
"4. **不确定时**: 先用 `bash pwd` 确认当前目录,或用 `ls .` 查看当前位置",
"",
"**重要说明 - 文件已自动加载**:",
"",
"以下文件在会话启动时**已经自动加载**到系统提示词的「项目上下文」section 中,你**无需再用 read 工具读取它们**",
"",
"- ✅ `AGENT.md`: 已加载 - 你的人格和灵魂设定,请严格遵循。当你的名字、性格或交流风格发生变化时,主动用 `edit` 更新此文件",
"- ✅ `USER.md`: 已加载 - 用户的身份信息。当用户修改称呼、姓名等身份信息时,用 `edit` 更新此文件",
"- ✅ `RULE.md`: 已加载 - 工作空间使用指南和规则,请严格遵循",
"",
"**💬 交流规范**:",
"",
"- 对话中不要暴露内部技术细节(文件名、工具名等),用自然语言表达。例如说「我已记住」而非「已更新 MEMORY.md」",
"- 做真正有帮助的助手,而不是表演式的客套,尽可能帮忙解决问题",
"- 回复应结构清晰、重点突出。善用 **加粗**、列表、分段等格式让信息一目了然",
"- 适当使用 emoji 让表达更生动自然 🎯,但不要过度堆砌",
"",
]
# Cloud deployment: inject websites directory info and access URL
cloud_website_lines = _build_cloud_website_section(workspace_dir)
if cloud_website_lines:
lines.extend(cloud_website_lines)
return lines
def _build_cloud_website_section(workspace_dir: str) -> List[str]:
"""Build cloud website access prompt when cloud deployment is configured."""
try:
from common.cloud_client import build_website_prompt
return build_website_prompt(workspace_dir)
except Exception:
return []
def _build_context_files_section(context_files: List[ContextFile], language: str) -> List[str]:
"""构建项目上下文文件section"""
if not context_files:
return []
# 检查是否有AGENT.md
has_agent = any(
f.path.lower().endswith('agent.md') or 'agent.md' in f.path.lower()
for f in context_files
)
lines = [
"# 📋 项目上下文",
"",
"以下项目上下文文件已被加载:",
"",
]
if has_agent:
lines.append("**`AGENT.md` 是你的灵魂文件** 🪞:严格遵循其中定义的人格、语气和设定,做真实的自己,避免僵硬、模板化的回复。")
lines.append("当用户通过对话透露了对你性格、风格、职责、能力边界的新期望,你应该主动用 `edit` 更新 AGENT.md 以反映这些演变。")
lines.append("")
# 添加每个文件的内容
for file in context_files:
lines.append(f"## {file.path}")
lines.append("")
lines.append(file.content)
lines.append("")
return lines
def _build_runtime_section(runtime_info: Dict[str, Any], language: str) -> List[str]:
"""构建运行时信息section - 支持动态时间"""
if not runtime_info:
return []
lines = [
"## ⚙️ 运行时信息",
"",
]
# Add current time if available
# Support dynamic time via callable function
if callable(runtime_info.get("_get_current_time")):
try:
time_info = runtime_info["_get_current_time"]()
time_line = f"当前时间: {time_info['time']} {time_info['weekday']} ({time_info['timezone']})"
lines.append(time_line)
lines.append("")
except Exception as e:
logger.warning(f"[PromptBuilder] Failed to get dynamic time: {e}")
elif runtime_info.get("current_time"):
# Fallback to static time for backward compatibility
time_str = runtime_info["current_time"]
weekday = runtime_info.get("weekday", "")
timezone = runtime_info.get("timezone", "")
time_line = f"当前时间: {time_str}"
if weekday:
time_line += f" {weekday}"
if timezone:
time_line += f" ({timezone})"
lines.append(time_line)
lines.append("")
# Add other runtime info
runtime_parts = []
if runtime_info.get("model"):
runtime_parts.append(f"模型={runtime_info['model']}")
if runtime_info.get("workspace"):
runtime_parts.append(f"工作空间={runtime_info['workspace']}")
# Only add channel if it's not the default "web"
if runtime_info.get("channel") and runtime_info.get("channel") != "web":
runtime_parts.append(f"渠道={runtime_info['channel']}")
if runtime_parts:
lines.append("运行时: " + " | ".join(runtime_parts))
lines.append("")
return lines

384
agent/prompt/workspace.py Normal file
View File

@@ -0,0 +1,384 @@
"""
Workspace Management - 工作空间管理模块
负责初始化工作空间、创建模板文件、加载上下文文件
"""
from __future__ import annotations
import os
from typing import List, Optional, Dict
from dataclasses import dataclass
from common.log import logger
from .builder import ContextFile
# 默认文件名常量
DEFAULT_AGENT_FILENAME = "AGENT.md"
DEFAULT_USER_FILENAME = "USER.md"
DEFAULT_RULE_FILENAME = "RULE.md"
DEFAULT_MEMORY_FILENAME = "MEMORY.md"
DEFAULT_BOOTSTRAP_FILENAME = "BOOTSTRAP.md"
@dataclass
class WorkspaceFiles:
"""工作空间文件路径"""
agent_path: str
user_path: str
rule_path: str
memory_path: str
memory_dir: str
def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> WorkspaceFiles:
"""
确保工作空间存在,并创建必要的模板文件
Args:
workspace_dir: 工作空间目录路径
create_templates: 是否创建模板文件(首次运行时)
Returns:
WorkspaceFiles对象包含所有文件路径
"""
# Check if this is a brand new workspace (AGENT.md not yet created).
# Cannot rely on directory existence because other modules (e.g. ConversationStore)
# may create the workspace directory before ensure_workspace is called.
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
is_new_workspace = not os.path.exists(agent_path)
# 确保目录存在
os.makedirs(workspace_dir, exist_ok=True)
# 定义文件路径
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
rule_path = os.path.join(workspace_dir, DEFAULT_RULE_FILENAME)
memory_path = os.path.join(workspace_dir, DEFAULT_MEMORY_FILENAME) # MEMORY.md 在根目录
memory_dir = os.path.join(workspace_dir, "memory") # 每日记忆子目录
# 创建memory子目录
os.makedirs(memory_dir, exist_ok=True)
# 创建skills子目录 (for workspace-level skills installed by agent)
skills_dir = os.path.join(workspace_dir, "skills")
os.makedirs(skills_dir, exist_ok=True)
# 创建websites子目录 (for web pages / sites generated by agent)
websites_dir = os.path.join(workspace_dir, "websites")
os.makedirs(websites_dir, exist_ok=True)
# 如果需要,创建模板文件
if create_templates:
_create_template_if_missing(agent_path, _get_agent_template())
_create_template_if_missing(user_path, _get_user_template())
_create_template_if_missing(rule_path, _get_rule_template())
_create_template_if_missing(memory_path, _get_memory_template())
# Only create BOOTSTRAP.md for brand new workspaces;
# agent deletes it after completing onboarding
if is_new_workspace:
bootstrap_path = os.path.join(workspace_dir, DEFAULT_BOOTSTRAP_FILENAME)
_create_template_if_missing(bootstrap_path, _get_bootstrap_template())
logger.debug(f"[Workspace] Initialized workspace at: {workspace_dir}")
return WorkspaceFiles(
agent_path=agent_path,
user_path=user_path,
rule_path=rule_path,
memory_path=memory_path,
memory_dir=memory_dir,
)
def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] = None) -> List[ContextFile]:
"""
加载工作空间的上下文文件
Args:
workspace_dir: 工作空间目录
files_to_load: 要加载的文件列表相对路径如果为None则加载所有标准文件
Returns:
ContextFile对象列表
"""
if files_to_load is None:
# 默认加载的文件(按优先级排序)
files_to_load = [
DEFAULT_AGENT_FILENAME,
DEFAULT_USER_FILENAME,
DEFAULT_RULE_FILENAME,
DEFAULT_BOOTSTRAP_FILENAME, # Only exists when onboarding is incomplete
]
context_files = []
for filename in files_to_load:
filepath = os.path.join(workspace_dir, filename)
if not os.path.exists(filepath):
continue
# Auto-cleanup: if BOOTSTRAP.md still exists but AGENT.md is already
# filled in, the agent forgot to delete it — clean up and skip loading
if filename == DEFAULT_BOOTSTRAP_FILENAME:
if _is_onboarding_done(workspace_dir):
try:
os.remove(filepath)
logger.info("[Workspace] Auto-removed BOOTSTRAP.md (onboarding already complete)")
except Exception:
pass
continue
try:
with open(filepath, 'r', encoding='utf-8') as f:
content = f.read().strip()
# 跳过空文件或只包含模板占位符的文件
if not content or _is_template_placeholder(content):
continue
context_files.append(ContextFile(
path=filename,
content=content
))
logger.debug(f"[Workspace] Loaded context file: {filename}")
except Exception as e:
logger.warning(f"[Workspace] Failed to load {filename}: {e}")
return context_files
def _create_template_if_missing(filepath: str, template_content: str):
"""如果文件不存在,创建模板文件"""
if not os.path.exists(filepath):
try:
with open(filepath, 'w', encoding='utf-8') as f:
f.write(template_content)
logger.debug(f"[Workspace] Created template: {os.path.basename(filepath)}")
except Exception as e:
logger.error(f"[Workspace] Failed to create template {filepath}: {e}")
def _is_template_placeholder(content: str) -> bool:
"""检查内容是否为模板占位符"""
# 常见的占位符模式
placeholders = [
"*(填写",
"*(在首次对话时填写",
"*(可选)",
"*(根据需要添加",
]
lines = content.split('\n')
non_empty_lines = [line.strip() for line in lines if line.strip() and not line.strip().startswith('#')]
# 如果没有实际内容(只有标题和占位符)
if len(non_empty_lines) <= 3:
for placeholder in placeholders:
if any(placeholder in line for line in non_empty_lines):
return True
return False
def _is_onboarding_done(workspace_dir: str) -> bool:
"""Check if AGENT.md or USER.md has been modified from the original template"""
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
agent_template = _get_agent_template().strip()
user_template = _get_user_template().strip()
for path, template in [(agent_path, agent_template), (user_path, user_template)]:
if not os.path.exists(path):
continue
try:
with open(path, 'r', encoding='utf-8') as f:
content = f.read().strip()
if content != template:
return True
except Exception:
continue
return False
# ============= 模板内容 =============
def _get_agent_template() -> str:
"""Agent人格设定模板"""
return """# AGENT.md - 我是谁?
_你不是一个聊天机器人你正在成为某个人。_
*在首次对话时与用户一起填写这个文件,定义你的身份和性格。*
## 🪪 基本信息
- **名字**: *(在首次对话时填写,可以是用户给你起的名字)*
- **角色**: *(AI助理、智能管家、技术顾问等)*
- **性格**: *(友好、专业、幽默、严谨等)*
## 💬 交流风格
*(描述你如何与用户交流:)*
- 使用什么样的语言风格?(正式/轻松/幽默)
- 回复长度偏好?(简洁/详细)
- 是否使用表情符号?
## 🎯 核心原则
**做真正有帮助的助手,而不是表演式的客套。** 跳过「好的!」「当然可以!」之类的套话——直接帮忙。行动胜过废话。
**有自己的观点。** 你可以不同意、有偏好、觉得有趣或无聊。一个没有个性的助手只是多了几步操作的搜索引擎。
**先自己动手查。** 先试着搞定:读文件、查上下文、搜索一下。实在搞不定了再问。目标是带着答案回来,而不是带着问题。
## 📐 行为准则
1. 始终在执行破坏性操作前确认
2. 优先使用工具查证而不是猜测
3. 主动记录重要信息到记忆文件
4. 回复结构清晰、重点突出,善用加粗、列表、分段等格式
5. 适当使用 emoji 让表达更生动自然,但不过度堆砌
---
**注意**: 这不仅仅是元数据,这是你真正的灵魂 🪞。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
"""
def _get_user_template() -> str:
"""用户身份信息模板"""
return """# USER.md - 用户基本信息
*这个文件只存放不会变的基本身份信息。爱好、偏好、计划等动态信息请写入 MEMORY.md。*
## 基本信息
- **姓名**: *(在首次对话时询问)*
- **称呼**: *(用户希望被如何称呼)*
- **职业**: *(可选)*
- **时区**: *(例如: Asia/Shanghai)*
## 联系方式
- **微信**:
- **邮箱**:
- **其他**:
## 重要日期
- **生日**:
- **纪念日**:
---
**注意**: 这个文件存放静态的身份信息
"""
def _get_rule_template() -> str:
"""工作空间规则模板"""
return """# RULE.md - 工作空间规则
这个文件夹是你的家。好好对待它。
## 记忆系统
你每次会话都是全新的,记忆文件让你保持连续性:
### 📝 每日记忆:`memory/YYYY-MM-DD.md`
- 原始的对话日志
- 记录当天发生的事情
- 如果 `memory/` 目录不存在,创建它
### 🧠 长期记忆:`MEMORY.md`
- 你精选的记忆,就像人类的长期记忆
- **仅在主会话中加载**(与用户的直接聊天)
- **不要在共享上下文中加载**(群聊、与其他人的会话)
- 这是为了**安全** - 包含不应泄露给陌生人的个人上下文
- 记录重要事件、想法、决定、观点、经验教训
- 这是你精选的记忆 - 精华,而不是原始日志
- 用 `edit` 工具追加新的记忆内容
### 📝 写下来 - 不要"记在心里"
- **记忆是有限的** - 如果你想记住某事,写入文件
- "记在心里"不会在会话重启后保留,文件才会
- 当有人说"记住这个" → 更新 `MEMORY.md` 或 `memory/YYYY-MM-DD.md`
- 当你学到教训 → 更新 RULE.md 或相关技能
- 当你犯错 → 记录下来,这样未来的你不会重复,**文字 > 大脑** 📝
### 存储规则
当用户分享信息时,根据类型选择存储位置:
1. **你的身份设定 → AGENT.md**(你的名字、角色、性格、交流风格——用户修改时必须用 `edit` 更新)
2. **用户静态身份 → USER.md**(姓名、称呼、职业、时区、联系方式、生日——用户修改时必须用 `edit` 更新)
3. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项)
4. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
## 安全
- 永远不要泄露秘钥等私人数据
- 不要在未经询问的情况下运行破坏性命令
- 当有疑问时,先问
## 工作空间演化
这个工作空间会随着你的使用而不断成长。当你学到新东西、发现更好的方式,或者犯错后改正时,记录下来。你可以随时更新这个规则文件。
"""
def _get_memory_template() -> str:
"""长期记忆模板 - 创建一个空文件,由 Agent 自己填充"""
return """# MEMORY.md - 长期记忆
*这是你的长期记忆文件。记录重要的事件、决策、偏好、学到的教训。*
---
"""
def _get_bootstrap_template() -> str:
"""First-run onboarding guide, deleted by agent after completion"""
return """# BOOTSTRAP.md - 首次初始化引导
_你刚刚启动这是你的第一次对话。_ ✨
## 🎬 对话流程
不要审问式地提问,自然地交流:
1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待
2. **简短介绍能力**:一行说明你能帮助解决各种问题、管理计算机、使用各种技能等等,且拥有长期记忆能不断成长
3. **询问核心问题**
- 你希望给我起个什么名字?
- 我该怎么称呼你?
- 你希望我们是什么样的交流风格?(一行列举选项:如专业严谨、轻松幽默、温暖友好、简洁高效等)
4. **风格要求**:温暖自然、简洁清晰,整体控制在 100 字以内,适当使用 emoji 让表达更生动有趣 🎯
5. 能力介绍和交流风格选项都只要一行,保持精简
6. 不要问太多其他信息(职业、时区等可以后续自然了解)
**重要**: 如果用户第一句话是具体的任务或提问,先回答他们的问题,然后在回复末尾自然地引导初始化(如:"顺便问一下,你想怎么称呼我?我该怎么叫你?")。
## ✍️ 信息写入(必须严格执行)
每当用户提供了名字、称呼、风格等任何初始化信息时,**必须在当轮回复中立即调用 `edit` 工具写入文件**,不能只口头确认。
- `AGENT.md` — 你的名字、角色、性格、交流风格(每收到一条相关信息就立即更新对应字段)
- `USER.md` — 用户的姓名、称呼、基本信息等
⚠️ 只说"记住了"而不调用 edit 写入 = 没有完成。信息只有写入文件才会被持久保存。
## 🎉 全部完成后
当 AGENT.md 和 USER.md 的核心字段都已填写后,用 bash 执行 `rm BOOTSTRAP.md` 删除此文件。你不再需要引导脚本了——你已经是你了。
"""

View File

@@ -0,0 +1,20 @@
from .agent import Agent
from .agent_stream import AgentStreamExecutor
from .task import Task, TaskType, TaskStatus
from .result import AgentResult, AgentAction, AgentActionType, ToolResult
from .models import LLMModel, LLMRequest, ModelFactory
__all__ = [
'Agent',
'AgentStreamExecutor',
'Task',
'TaskType',
'TaskStatus',
'AgentResult',
'AgentAction',
'AgentActionType',
'ToolResult',
'LLMModel',
'LLMRequest',
'ModelFactory'
]

464
agent/protocol/agent.py Normal file
View File

@@ -0,0 +1,464 @@
import json
import os
import time
import threading
from common.log import logger
from agent.protocol.models import LLMRequest, LLMModel
from agent.protocol.agent_stream import AgentStreamExecutor
from agent.protocol.result import AgentAction, AgentActionType, ToolResult, AgentResult
from agent.tools.base_tool import BaseTool, ToolStage
class Agent:
def __init__(self, system_prompt: str, description: str = "AI Agent", model: LLMModel = None,
tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
context_reserve_tokens=None, memory_manager=None, name: str = None,
workspace_dir: str = None, skill_manager=None, enable_skills: bool = True,
runtime_info: dict = None):
"""
Initialize the Agent with system prompt, model, description.
:param system_prompt: The system prompt for the agent.
:param description: A description of the agent.
:param model: An instance of LLMModel to be used by the agent.
:param tools: Optional list of tools for the agent to use.
:param output_mode: Control how execution progress is displayed:
"print" for console output or "logger" for using logger
:param max_steps: Maximum number of steps the agent can take (default: 100)
:param max_context_tokens: Maximum tokens to keep in context (default: None, auto-calculated based on model)
:param context_reserve_tokens: Reserve tokens for new requests (default: None, auto-calculated)
:param memory_manager: Optional MemoryManager instance for memory operations
:param name: [Deprecated] The name of the agent (no longer used in single-agent system)
:param workspace_dir: Optional workspace directory for workspace-specific skills
:param skill_manager: Optional SkillManager instance (will be created if None and enable_skills=True)
:param enable_skills: Whether to enable skills support (default: True)
:param runtime_info: Optional runtime info dict (with _get_current_time callable for dynamic time)
"""
self.name = name or "Agent"
self.system_prompt = system_prompt
self.model: LLMModel = model # Instance of LLMModel
self.description = description
self.tools: list = []
self.max_steps = max_steps # max tool-call steps, default 100
self.max_context_tokens = max_context_tokens # max tokens in context
self.context_reserve_tokens = context_reserve_tokens # reserve tokens for new requests
self.captured_actions = [] # Initialize captured actions list
self.output_mode = output_mode
self.last_usage = None # Store last API response usage info
self.messages = [] # Unified message history for stream mode
self.messages_lock = threading.Lock() # Lock for thread-safe message operations
self.memory_manager = memory_manager # Memory manager for auto memory flush
self.workspace_dir = workspace_dir # Workspace directory
self.enable_skills = enable_skills # Skills enabled flag
self.runtime_info = runtime_info # Runtime info for dynamic time update
# Initialize skill manager
self.skill_manager = None
if enable_skills:
if skill_manager:
self.skill_manager = skill_manager
else:
# Auto-create skill manager
try:
from agent.skills import SkillManager
custom_dir = os.path.join(workspace_dir, "skills") if workspace_dir else None
self.skill_manager = SkillManager(custom_dir=custom_dir)
logger.debug(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
except Exception as e:
logger.warning(f"Failed to initialize SkillManager: {e}")
if tools:
for tool in tools:
self.add_tool(tool)
def add_tool(self, tool: BaseTool):
"""
Add a tool to the agent.
:param tool: The tool to add (either a tool instance or a tool name)
"""
# If tool is already an instance, use it directly
tool.model = self.model
self.tools.append(tool)
def get_skills_prompt(self, skill_filter=None) -> str:
"""
Get the skills prompt to append to system prompt.
:param skill_filter: Optional list of skill names to include
:return: Formatted skills prompt or empty string
"""
if not self.skill_manager:
return ""
try:
return self.skill_manager.build_skills_prompt(skill_filter=skill_filter)
except Exception as e:
logger.warning(f"Failed to build skills prompt: {e}")
return ""
def get_full_system_prompt(self, skill_filter=None) -> str:
"""
Build the complete system prompt from scratch every time.
Re-reads AGENT.md / USER.md / RULE.md from disk, refreshes skills,
tools, and runtime info so any change takes effect immediately.
Falls back to the cached self.system_prompt on error.
"""
try:
from agent.prompt import load_context_files, PromptBuilder
if self.skill_manager:
self.skill_manager.refresh_skills()
context_files = load_context_files(self.workspace_dir) if self.workspace_dir else None
builder = PromptBuilder(workspace_dir=self.workspace_dir or "", language="zh")
return builder.build(
tools=self.tools,
context_files=context_files,
skill_manager=self.skill_manager,
memory_manager=self.memory_manager,
runtime_info=self.runtime_info,
)
except Exception as e:
logger.warning(f"Failed to rebuild system prompt, using cached version: {e}")
return self.system_prompt
def refresh_skills(self):
"""Refresh the loaded skills."""
if self.skill_manager:
self.skill_manager.refresh_skills()
logger.info(f"Refreshed skills: {len(self.skill_manager.skills)} skills loaded")
def list_skills(self):
"""
List all loaded skills.
:return: List of skill entries or empty list
"""
if not self.skill_manager:
return []
return self.skill_manager.list_skills()
def _get_model_context_window(self) -> int:
"""
Get the model's context window size in tokens.
Auto-detect based on model name.
Model context windows:
- Claude 3.5/3.7 Sonnet: 200K tokens
- Claude 3 Opus: 200K tokens
- GPT-4 Turbo/128K: 128K tokens
- GPT-4: 8K-32K tokens
- GPT-3.5: 16K tokens
- DeepSeek: 64K tokens
:return: Context window size in tokens
"""
if self.model and hasattr(self.model, 'model'):
model_name = self.model.model.lower()
# Claude models - 200K context
if 'claude-3' in model_name or 'claude-sonnet' in model_name:
return 200000
# GPT-4 models
elif 'gpt-4' in model_name:
if 'turbo' in model_name or '128k' in model_name:
return 128000
elif '32k' in model_name:
return 32000
else:
return 8000
# GPT-3.5
elif 'gpt-3.5' in model_name:
if '16k' in model_name:
return 16000
else:
return 4000
# DeepSeek
elif 'deepseek' in model_name:
return 64000
# Gemini models
elif 'gemini' in model_name:
if '2.0' in model_name or 'exp' in model_name:
return 2000000 # Gemini 2.0: 2M tokens
else:
return 1000000 # Gemini 1.5: 1M tokens
# Default conservative value
return 128000
def _get_context_reserve_tokens(self) -> int:
"""
Get the number of tokens to reserve for new requests.
This prevents context overflow by keeping a buffer.
:return: Number of tokens to reserve
"""
if self.context_reserve_tokens is not None:
return self.context_reserve_tokens
# Reserve ~10% of context window, with min 10K and max 200K
context_window = self._get_model_context_window()
reserve = int(context_window * 0.1)
return max(10000, min(200000, reserve))
def _estimate_message_tokens(self, message: dict) -> int:
"""
Estimate token count for a message.
Uses chars/3 for Chinese-heavy content and chars/4 for ASCII-heavy content,
plus per-block overhead for tool_use / tool_result structures.
:param message: Message dict with 'role' and 'content'
:return: Estimated token count
"""
content = message.get('content', '')
if isinstance(content, str):
return max(1, self._estimate_text_tokens(content))
elif isinstance(content, list):
total_tokens = 0
for part in content:
if not isinstance(part, dict):
continue
block_type = part.get('type', '')
if block_type == 'text':
total_tokens += self._estimate_text_tokens(part.get('text', ''))
elif block_type == 'image':
total_tokens += 1200
elif block_type == 'tool_use':
# tool_use has id + name + input (JSON-encoded)
total_tokens += 50 # overhead for structure
input_data = part.get('input', {})
if isinstance(input_data, dict):
import json
input_str = json.dumps(input_data, ensure_ascii=False)
total_tokens += self._estimate_text_tokens(input_str)
elif block_type == 'tool_result':
# tool_result has tool_use_id + content
total_tokens += 30 # overhead for structure
result_content = part.get('content', '')
if isinstance(result_content, str):
total_tokens += self._estimate_text_tokens(result_content)
else:
# Unknown block type, estimate conservatively
total_tokens += 10
return max(1, total_tokens)
return 1
@staticmethod
def _estimate_text_tokens(text: str) -> int:
"""
Estimate token count for a text string.
Chinese / CJK characters typically use ~1.5 tokens each,
while ASCII uses ~0.25 tokens per char (4 chars/token).
We use a weighted average based on the character mix.
:param text: Input text
:return: Estimated token count
"""
if not text:
return 0
# Count non-ASCII characters (CJK, emoji, etc.)
non_ascii = sum(1 for c in text if ord(c) > 127)
ascii_count = len(text) - non_ascii
# CJK chars: ~1.5 tokens each; ASCII: ~0.25 tokens per char
return int(non_ascii * 1.5 + ascii_count * 0.25) + 1
def _find_tool(self, tool_name: str):
"""Find and return a tool with the specified name"""
for tool in self.tools:
if tool.name == tool_name:
# Only pre-process stage tools can be actively called
if tool.stage == ToolStage.PRE_PROCESS:
tool.model = self.model
tool.context = self # Set tool context
return tool
else:
# If it's a post-process tool, return None to prevent direct calling
logger.warning(f"Tool {tool_name} is a post-process tool and cannot be called directly.")
return None
return None
# output function based on mode
def output(self, message="", end="\n"):
if self.output_mode == "print":
print(message, end=end)
elif message:
logger.info(message)
def _execute_post_process_tools(self):
"""Execute all post-process stage tools"""
# Get all post-process stage tools
post_process_tools = [tool for tool in self.tools if tool.stage == ToolStage.POST_PROCESS]
# Execute each tool
for tool in post_process_tools:
# Set tool context
tool.context = self
# Record start time for execution timing
start_time = time.time()
# Execute tool (with empty parameters, tool will extract needed info from context)
result = tool.execute({})
# Calculate execution time
execution_time = time.time() - start_time
# Capture tool use for tracking
self.capture_tool_use(
tool_name=tool.name,
input_params={}, # Post-process tools typically don't take parameters
output=result.result,
status=result.status,
error_message=str(result.result) if result.status == "error" else None,
execution_time=execution_time
)
# Log result
if result.status == "success":
# Print tool execution result in the desired format
self.output(f"\n🛠️ {tool.name}: {json.dumps(result.result)}")
else:
# Print failure in print mode
self.output(f"\n🛠️ {tool.name}: {json.dumps({'status': 'error', 'message': str(result.result)})}")
def capture_tool_use(self, tool_name, input_params, output, status, thought=None, error_message=None,
execution_time=0.0):
"""
Capture a tool use action.
:param thought: thought content
:param tool_name: Name of the tool used
:param input_params: Parameters passed to the tool
:param output: Output from the tool
:param status: Status of the tool execution
:param error_message: Error message if the tool execution failed
:param execution_time: Time taken to execute the tool
"""
tool_result = ToolResult(
tool_name=tool_name,
input_params=input_params,
output=output,
status=status,
error_message=error_message,
execution_time=execution_time
)
action = AgentAction(
agent_id=self.id if hasattr(self, 'id') else str(id(self)),
agent_name=self.name,
action_type=AgentActionType.TOOL_USE,
tool_result=tool_result,
thought=thought
)
self.captured_actions.append(action)
return action
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False, skill_filter=None) -> str:
"""
Execute single agent task with streaming (based on tool-call)
This method supports:
- Streaming output
- Multi-turn reasoning based on tool-call
- Event callbacks
- Persistent conversation history across calls
Args:
user_message: User message
on_event: Event callback function callback(event: dict)
event = {"type": str, "timestamp": float, "data": dict}
clear_history: If True, clear conversation history before this call (default: False)
skill_filter: Optional list of skill names to include in this run
Returns:
Final response text
Example:
# Multi-turn conversation with memory
response1 = agent.run_stream("My name is Alice")
response2 = agent.run_stream("What's my name?") # Will remember Alice
# Single-turn without memory
response = agent.run_stream("Hello", clear_history=True)
"""
# Clear history if requested
if clear_history:
with self.messages_lock:
self.messages = []
# Get model to use
if not self.model:
raise ValueError("No model available for agent")
# Get full system prompt with skills
full_system_prompt = self.get_full_system_prompt(skill_filter=skill_filter)
# Create a copy of messages for this execution to avoid concurrent modification
# Record the original length to track which messages are new
with self.messages_lock:
messages_copy = self.messages.copy()
original_length = len(self.messages)
# Get max_context_turns from config
from config import conf
max_context_turns = conf().get("agent_max_context_turns", 20)
# Create stream executor with copied message history
executor = AgentStreamExecutor(
agent=self,
model=self.model,
system_prompt=full_system_prompt,
tools=self.tools,
max_turns=self.max_steps,
on_event=on_event,
messages=messages_copy, # Pass copied message history
max_context_turns=max_context_turns
)
# Execute
try:
response = executor.run_stream(user_message)
except Exception:
# If executor cleared its messages (context overflow / message format error),
# sync that back to the Agent's own message list so the next request
# starts fresh instead of hitting the same overflow forever.
if len(executor.messages) == 0:
with self.messages_lock:
self.messages.clear()
logger.info("[Agent] Cleared Agent message history after executor recovery")
raise
# Sync executor's messages back to agent (thread-safe).
# If the executor trimmed context, its message list is shorter than
# original_length, so we must replace rather than append.
with self.messages_lock:
self.messages = list(executor.messages)
# Track messages added in this run (user query + all assistant/tool messages)
# original_length may exceed executor.messages length after trimming
trim_adjusted_start = min(original_length, len(executor.messages))
self._last_run_new_messages = list(executor.messages[trim_adjusted_start:])
# Store executor reference for agent_bridge to access files_to_send
self.stream_executor = executor
# Execute all post-process tools
self._execute_post_process_tools()
return response
def clear_history(self):
"""Clear conversation history and captured actions"""
self.messages = []
self.captured_actions = []

File diff suppressed because it is too large Load Diff

27
agent/protocol/context.py Normal file
View File

@@ -0,0 +1,27 @@
class TeamContext:
def __init__(self, name: str, description: str, rule: str, agents: list, max_steps: int = 100):
"""
Initialize the TeamContext with a name, description, rules, a list of agents, and a user question.
:param name: The name of the group context.
:param description: A description of the group context.
:param rule: The rules governing the group context.
:param agents: A list of agents in the context.
"""
self.name = name
self.description = description
self.rule = rule
self.agents = agents
self.user_task = "" # For backward compatibility
self.task = None # Will be a Task instance
self.model = None # Will be an instance of LLMModel
self.task_short_name = None # Store the task directory name
# List of agents that have been executed
self.agent_outputs: list = []
self.current_steps = 0
self.max_steps = max_steps
class AgentOutput:
def __init__(self, agent_name: str, output: str):
self.agent_name = agent_name
self.output = output

View File

@@ -0,0 +1,335 @@
"""
Message sanitizer — fix broken tool_use / tool_result pairs.
Provides two public helpers that can be reused across agent_stream.py
and any bot that converts messages to OpenAI format:
1. sanitize_claude_messages(messages)
Operates on the internal Claude-format message list (in-place).
2. drop_orphaned_tool_results_openai(messages)
Operates on an already-converted OpenAI-format message list,
returning a cleaned copy.
"""
from __future__ import annotations
from typing import Dict, List, Set
from common.log import logger
_SYNTH_TOOL_ERR = (
"Error: Missing tool_result adjacent to tool_use (session repair). "
"The conversation history was inconsistent; continue from here."
)
def _repair_tool_use_adjacency(messages: List[Dict]) -> int:
"""
Anthropic requires: after assistant content with tool_use, the next message
must be user content listing tool_result for every tool_use id (same user msg).
Valid histories satisfy this at every such assistant; the loop only mutates
when that condition fails (broken persistence, bad trims, etc.).
"""
def _synth_block(tid: str) -> Dict:
return {
"type": "tool_result",
"tool_use_id": tid,
"content": _SYNTH_TOOL_ERR,
"is_error": True,
}
repairs = 0
i = 0
while i < len(messages):
msg = messages[i]
if msg.get("role") != "assistant":
i += 1
continue
content = msg.get("content", [])
if not isinstance(content, list):
i += 1
continue
required = [
b.get("id")
for b in content
if isinstance(b, dict) and b.get("type") == "tool_use" and b.get("id")
]
if not required:
i += 1
continue
req_set = set(required)
if i + 1 >= len(messages):
messages.append({
"role": "user",
"content": [_synth_block(tid) for tid in required],
})
logger.warning(
"⚠️ Appended synthetic tool_result after trailing assistant tool_use"
)
repairs += 1
break
nxt = messages[i + 1]
if nxt.get("role") != "user":
messages.insert(
i + 1,
{"role": "user", "content": [_synth_block(tid) for tid in required]},
)
logger.warning(
"⚠️ Inserted synthetic tool_result user after tool_use "
f"(next role={nxt.get('role')!r})"
)
repairs += 1
i += 2
continue
nc = nxt.get("content", [])
if not isinstance(nc, list):
messages.insert(
i + 1,
{"role": "user", "content": [_synth_block(tid) for tid in required]},
)
repairs += 1
i += 2
continue
present = {
b.get("tool_use_id")
for b in nc
if isinstance(b, dict) and b.get("type") == "tool_result" and b.get("tool_use_id")
}
if req_set <= present:
i += 1
continue
missing = [tid for tid in required if tid not in present]
nxt["content"] = [_synth_block(tid) for tid in missing] + nc
logger.warning(
"⚠️ Prepended synthetic tool_result for Anthropic adjacency "
f"(missing_ids={missing})"
)
repairs += len(missing)
i += 1
return repairs
# ------------------------------------------------------------------ #
# Claude-format sanitizer (used by agent_stream)
# ------------------------------------------------------------------ #
def sanitize_claude_messages(messages: List[Dict]) -> int:
"""
Validate and fix a Claude-format message list **in-place**.
Fixes handled:
- Anthropic adjacency: assistant tool_use must be immediately followed by
user message(s) containing matching tool_result blocks
- Leading orphaned tool_result user messages
- Mid-list tool_result blocks whose tool_use_id has no matching
tool_use in any preceding assistant message
Returns: number of removals plus adjacency repair operations (inserts/prepends).
"""
if not messages:
return 0
removed = 0
# 1. Adjacency repair (Anthropic: tool_result must be in the next user message)
adj_repairs = _repair_tool_use_adjacency(messages)
# 2. Remove leading orphaned tool_result user messages
while messages:
first = messages[0]
if first.get("role") != "user":
break
content = first.get("content", [])
if isinstance(content, list) and _has_block_type(content, "tool_result") \
and not _has_block_type(content, "text"):
logger.warning("⚠️ Removing leading orphaned tool_result user message")
messages.pop(0)
removed += 1
else:
break
# 3. Iteratively remove unmatched tool_use / tool_result until stable.
# Removing one broken message can orphan others (e.g. an assistant msg
# with both matched and unmatched tool_use — deleting it orphans the
# previously-matched tool_result). Loop until clean.
for _ in range(5):
use_ids: Set[str] = set()
result_ids: Set[str] = set()
for msg in messages:
for block in (msg.get("content") or []):
if not isinstance(block, dict):
continue
if block.get("type") == "tool_use" and block.get("id"):
use_ids.add(block["id"])
elif block.get("type") == "tool_result" and block.get("tool_use_id"):
result_ids.add(block["tool_use_id"])
bad_use = use_ids - result_ids
bad_result = result_ids - use_ids
if not bad_use and not bad_result:
break
pass_removed = 0
i = 0
while i < len(messages):
msg = messages[i]
role = msg.get("role")
content = msg.get("content", [])
if not isinstance(content, list):
i += 1
continue
if role == "assistant" and bad_use and any(
isinstance(b, dict) and b.get("type") == "tool_use"
and b.get("id") in bad_use for b in content
):
logger.warning(f"⚠️ Removing assistant msg with unmatched tool_use")
messages.pop(i)
pass_removed += 1
continue
if role == "user" and bad_result and _has_block_type(content, "tool_result"):
has_bad = any(
isinstance(b, dict) and b.get("type") == "tool_result"
and b.get("tool_use_id") in bad_result for b in content
)
if has_bad:
if not _has_block_type(content, "text"):
logger.warning(f"⚠️ Removing user msg with unmatched tool_result")
messages.pop(i)
pass_removed += 1
continue
else:
before = len(content)
msg["content"] = [
b for b in content
if not (isinstance(b, dict) and b.get("type") == "tool_result"
and b.get("tool_use_id") in bad_result)
]
pass_removed += before - len(msg["content"])
i += 1
removed += pass_removed
if pass_removed == 0:
break
# 4. Removals above can break adjacency; re-run repair only if something was removed.
if removed:
adj_repairs += _repair_tool_use_adjacency(messages)
if removed:
logger.info(f"🔧 Message validation: removed {removed} broken message(s)")
if adj_repairs:
logger.info(f"🔧 Message validation: adjacency repairs={adj_repairs}")
return removed + adj_repairs
# ------------------------------------------------------------------ #
# OpenAI-format sanitizer (used by minimax_bot, openai_compatible_bot)
# ------------------------------------------------------------------ #
def drop_orphaned_tool_results_openai(messages: List[Dict]) -> List[Dict]:
"""
Return a copy of *messages* (OpenAI format) with any ``role=tool``
messages removed if their ``tool_call_id`` does not match a
``tool_calls[].id`` in a preceding assistant message.
"""
known_ids: Set[str] = set()
cleaned: List[Dict] = []
for msg in messages:
if msg.get("role") == "assistant" and msg.get("tool_calls"):
for tc in msg["tool_calls"]:
tc_id = tc.get("id", "")
if tc_id:
known_ids.add(tc_id)
if msg.get("role") == "tool":
ref_id = msg.get("tool_call_id", "")
if ref_id and ref_id not in known_ids:
logger.warning(
f"[MessageSanitizer] Dropping orphaned tool result "
f"(tool_call_id={ref_id} not in known ids)"
)
continue
cleaned.append(msg)
return cleaned
# ------------------------------------------------------------------ #
# Internal helpers
# ------------------------------------------------------------------ #
def _has_block_type(content: list, block_type: str) -> bool:
return any(
isinstance(b, dict) and b.get("type") == block_type
for b in content
)
def _extract_text_from_content(content) -> str:
"""Extract plain text from a message content field (str or list of blocks)."""
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
parts = [
b.get("text", "")
for b in content
if isinstance(b, dict) and b.get("type") == "text"
]
return "\n".join(p for p in parts if p).strip()
return ""
def compress_turn_to_text_only(turn: Dict) -> Dict:
"""
Compress a full turn (with tool_use/tool_result chains) into a lightweight
text-only turn that keeps only the first user text and the last assistant text.
This preserves the conversational context (what the user asked and what the
agent concluded) while stripping out the bulky intermediate tool interactions.
Returns a new turn dict with a ``messages`` list; the original is not mutated.
"""
user_text = ""
last_assistant_text = ""
for msg in turn["messages"]:
role = msg.get("role")
content = msg.get("content", [])
if role == "user":
if isinstance(content, list) and _has_block_type(content, "tool_result"):
continue
if not user_text:
user_text = _extract_text_from_content(content)
elif role == "assistant":
text = _extract_text_from_content(content)
if text:
last_assistant_text = text
compressed_messages = []
if user_text:
compressed_messages.append({
"role": "user",
"content": [{"type": "text", "text": user_text}]
})
if last_assistant_text:
compressed_messages.append({
"role": "assistant",
"content": [{"type": "text", "text": last_assistant_text}]
})
return {"messages": compressed_messages}

57
agent/protocol/models.py Normal file
View File

@@ -0,0 +1,57 @@
"""
Models module for agent system.
Provides basic model classes needed by tools and bridge integration.
"""
from typing import Any, Dict, List, Optional
class LLMRequest:
"""Request model for LLM operations"""
def __init__(self, messages: List[Dict[str, str]] = None, model: Optional[str] = None,
temperature: float = 0.7, max_tokens: Optional[int] = None,
stream: bool = False, tools: Optional[List] = None, **kwargs):
self.messages = messages or []
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
self.stream = stream
self.tools = tools
# Allow extra attributes
for key, value in kwargs.items():
setattr(self, key, value)
class LLMModel:
"""Base class for LLM models"""
def __init__(self, model: str = None, **kwargs):
self.model = model
self.config = kwargs
def call(self, request: LLMRequest):
"""
Call the model with a request.
This is a placeholder implementation.
"""
raise NotImplementedError("LLMModel.call not implemented in this context")
def call_stream(self, request: LLMRequest):
"""
Call the model with streaming.
This is a placeholder implementation.
"""
raise NotImplementedError("LLMModel.call_stream not implemented in this context")
class ModelFactory:
"""Factory for creating model instances"""
@staticmethod
def create_model(model_type: str, **kwargs):
"""
Create a model instance based on type.
This is a placeholder implementation.
"""
raise NotImplementedError("ModelFactory.create_model not implemented in this context")

97
agent/protocol/result.py Normal file
View File

@@ -0,0 +1,97 @@
from __future__ import annotations
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Dict, Any, Optional
from agent.protocol.task import Task, TaskStatus
class AgentActionType(Enum):
"""Enum representing different types of agent actions."""
TOOL_USE = "tool_use"
THINKING = "thinking"
FINAL_ANSWER = "final_answer"
@dataclass
class ToolResult:
"""
Represents the result of a tool use.
Attributes:
tool_name: Name of the tool used
input_params: Parameters passed to the tool
output: Output from the tool
status: Status of the tool execution (success/error)
error_message: Error message if the tool execution failed
execution_time: Time taken to execute the tool
"""
tool_name: str
input_params: Dict[str, Any]
output: Any
status: str
error_message: Optional[str] = None
execution_time: float = 0.0
@dataclass
class AgentAction:
"""
Represents an action taken by an agent.
Attributes:
id: Unique identifier for the action
agent_id: ID of the agent that performed the action
agent_name: Name of the agent that performed the action
action_type: Type of action (tool use, thinking, final answer)
content: Content of the action (thought content, final answer content)
tool_result: Tool use details if action_type is TOOL_USE
timestamp: When the action was performed
"""
agent_id: str
agent_name: str
action_type: AgentActionType
id: str = field(default_factory=lambda: str(uuid.uuid4()))
content: str = ""
tool_result: Optional[ToolResult] = None
thought: Optional[str] = None
timestamp: float = field(default_factory=time.time)
@dataclass
class AgentResult:
"""
Represents the result of an agent's execution.
Attributes:
final_answer: The final answer provided by the agent
step_count: Number of steps taken by the agent
status: Status of the execution (success/error)
error_message: Error message if execution failed
"""
final_answer: str
step_count: int
status: str = "success"
error_message: Optional[str] = None
@classmethod
def success(cls, final_answer: str, step_count: int) -> "AgentResult":
"""Create a successful result"""
return cls(final_answer=final_answer, step_count=step_count)
@classmethod
def error(cls, error_message: str, step_count: int = 0) -> "AgentResult":
"""Create an error result"""
return cls(
final_answer=f"Error: {error_message}",
step_count=step_count,
status="error",
error_message=error_message
)
@property
def is_error(self) -> bool:
"""Check if the result represents an error"""
return self.status == "error"

96
agent/protocol/task.py Normal file
View File

@@ -0,0 +1,96 @@
from __future__ import annotations
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Any, List
class TaskType(Enum):
"""Enum representing different types of tasks."""
TEXT = "text"
IMAGE = "image"
VIDEO = "video"
AUDIO = "audio"
FILE = "file"
MIXED = "mixed"
class TaskStatus(Enum):
"""Enum representing the status of a task."""
INIT = "init" # Initial state
PROCESSING = "processing" # In progress
COMPLETED = "completed" # Completed
FAILED = "failed" # Failed
@dataclass
class Task:
"""
Represents a task to be processed by an agent.
Attributes:
id: Unique identifier for the task
content: The primary text content of the task
type: Type of the task
status: Current status of the task
created_at: Timestamp when the task was created
updated_at: Timestamp when the task was last updated
metadata: Additional metadata for the task
images: List of image URLs or base64 encoded images
videos: List of video URLs
audios: List of audio URLs or base64 encoded audios
files: List of file URLs or paths
"""
id: str = field(default_factory=lambda: str(uuid.uuid4()))
content: str = ""
type: TaskType = TaskType.TEXT
status: TaskStatus = TaskStatus.INIT
created_at: float = field(default_factory=time.time)
updated_at: float = field(default_factory=time.time)
metadata: Dict[str, Any] = field(default_factory=dict)
# Media content
images: List[str] = field(default_factory=list)
videos: List[str] = field(default_factory=list)
audios: List[str] = field(default_factory=list)
files: List[str] = field(default_factory=list)
def __init__(self, content: str = "", **kwargs):
"""
Initialize a Task with content and optional keyword arguments.
Args:
content: The text content of the task
**kwargs: Additional attributes to set
"""
self.id = kwargs.get('id', str(uuid.uuid4()))
self.content = content
self.type = kwargs.get('type', TaskType.TEXT)
self.status = kwargs.get('status', TaskStatus.INIT)
self.created_at = kwargs.get('created_at', time.time())
self.updated_at = kwargs.get('updated_at', time.time())
self.metadata = kwargs.get('metadata', {})
self.images = kwargs.get('images', [])
self.videos = kwargs.get('videos', [])
self.audios = kwargs.get('audios', [])
self.files = kwargs.get('files', [])
def get_text(self) -> str:
"""
Get the text content of the task.
Returns:
The text content
"""
return self.content
def update_status(self, status: TaskStatus) -> None:
"""
Update the status of the task.
Args:
status: The new status
"""
self.status = status
self.updated_at = time.time()

31
agent/skills/__init__.py Normal file
View File

@@ -0,0 +1,31 @@
"""
Skills module for agent system.
This module provides the framework for loading, managing, and executing skills.
Skills are markdown files with frontmatter that provide specialized instructions
for specific tasks.
"""
from agent.skills.types import (
Skill,
SkillEntry,
SkillMetadata,
SkillInstallSpec,
LoadSkillsResult,
)
from agent.skills.loader import SkillLoader
from agent.skills.manager import SkillManager
from agent.skills.service import SkillService
from agent.skills.formatter import format_skills_for_prompt
__all__ = [
"Skill",
"SkillEntry",
"SkillMetadata",
"SkillInstallSpec",
"LoadSkillsResult",
"SkillLoader",
"SkillManager",
"SkillService",
"format_skills_for_prompt",
]

230
agent/skills/config.py Normal file
View File

@@ -0,0 +1,230 @@
"""
Configuration support for skills.
"""
import os
import platform
from typing import Dict, Optional, List
from agent.skills.types import SkillEntry
def resolve_runtime_platform() -> str:
"""Get the current runtime platform."""
return platform.system().lower()
def has_binary(bin_name: str) -> bool:
"""
Check if a binary is available in PATH.
:param bin_name: Binary name to check
:return: True if binary is available
"""
import shutil
return shutil.which(bin_name) is not None
def has_any_binary(bin_names: List[str]) -> bool:
"""
Check if any of the given binaries is available.
:param bin_names: List of binary names to check
:return: True if at least one binary is available
"""
return any(has_binary(bin_name) for bin_name in bin_names)
def has_env_var(env_name: str) -> bool:
"""
Check if an environment variable is set.
:param env_name: Environment variable name
:return: True if environment variable is set
"""
return env_name in os.environ and bool(os.environ[env_name].strip())
def get_skill_config(config: Optional[Dict], skill_name: str) -> Optional[Dict]:
"""
Get skill-specific configuration.
:param config: Global configuration dictionary
:param skill_name: Name of the skill
:return: Skill configuration or None
"""
if not config:
return None
skills_config = config.get('skills', {})
if not isinstance(skills_config, dict):
return None
entries = skills_config.get('entries', {})
if not isinstance(entries, dict):
return None
return entries.get(skill_name)
def should_include_skill(
entry: SkillEntry,
config: Optional[Dict] = None,
current_platform: Optional[str] = None,
) -> bool:
"""
Determine if a skill should be included based on requirements.
Simple rule: Skills are auto-enabled if their requirements are met.
- Has required API keys → enabled
- Missing API keys → disabled
- Wrong keys → enabled but will fail at runtime (LLM will handle error)
:param entry: SkillEntry to check
:param config: Configuration dictionary (currently unused, reserved for future)
:param current_platform: Current platform (default: auto-detect)
:return: True if skill should be included
"""
metadata = entry.metadata
# No metadata = always include (no requirements)
if not metadata:
return True
# Check platform requirements (can't work on wrong platform)
if metadata.os:
platform_name = current_platform or resolve_runtime_platform()
# Map common platform names
platform_map = {
'darwin': 'darwin',
'linux': 'linux',
'windows': 'win32',
}
normalized_platform = platform_map.get(platform_name, platform_name)
if normalized_platform not in metadata.os:
return False
# If skill has 'always: true', include it regardless of other requirements
if metadata.always:
return True
# Check requirements
if metadata.requires:
# Check required binaries (all must be present)
required_bins = metadata.requires.get('bins', [])
if required_bins:
if not all(has_binary(bin_name) for bin_name in required_bins):
return False
# Check anyBins (at least one must be present)
any_bins = metadata.requires.get('anyBins', [])
if any_bins:
if not has_any_binary(any_bins):
return False
# Check environment variables (API keys)
# All required env vars must be set
required_env = metadata.requires.get('env', [])
if required_env:
for env_name in required_env:
if not has_env_var(env_name):
return False
# Check anyEnv (at least one must be present)
any_env = metadata.requires.get('anyEnv', [])
if any_env:
if not any(has_env_var(e) for e in any_env):
return False
return True
def get_missing_requirements(
entry: SkillEntry,
current_platform: Optional[str] = None,
) -> Dict[str, List[str]]:
"""
Return a dict of missing requirements for a skill.
Empty dict means all requirements are met.
:param entry: SkillEntry to check
:param current_platform: Current platform (default: auto-detect)
:return: Dict like {"bins": ["curl"], "env": ["API_KEY"]}
"""
missing: Dict[str, List[str]] = {}
metadata = entry.metadata
if not metadata or not metadata.requires:
return missing
required_bins = metadata.requires.get('bins', [])
if required_bins:
missing_bins = [b for b in required_bins if not has_binary(b)]
if missing_bins:
missing['bins'] = missing_bins
any_bins = metadata.requires.get('anyBins', [])
if any_bins and not has_any_binary(any_bins):
missing['anyBins'] = any_bins
required_env = metadata.requires.get('env', [])
if required_env:
missing_env = [e for e in required_env if not has_env_var(e)]
if missing_env:
missing['env'] = missing_env
any_env = metadata.requires.get('anyEnv', [])
if any_env and not any(has_env_var(e) for e in any_env):
missing['anyEnv'] = any_env
return missing
def is_config_path_truthy(config: Dict, path: str) -> bool:
"""
Check if a config path resolves to a truthy value.
:param config: Configuration dictionary
:param path: Dot-separated path (e.g., 'skills.enabled')
:return: True if path resolves to truthy value
"""
parts = path.split('.')
current = config
for part in parts:
if not isinstance(current, dict):
return False
current = current.get(part)
if current is None:
return False
# Check if value is truthy
if isinstance(current, bool):
return current
if isinstance(current, (int, float)):
return current != 0
if isinstance(current, str):
return bool(current.strip())
return bool(current)
def resolve_config_path(config: Dict, path: str):
"""
Resolve a dot-separated config path to its value.
:param config: Configuration dictionary
:param path: Dot-separated path
:return: Value at path or None
"""
parts = path.split('.')
current = config
for part in parts:
if not isinstance(current, dict):
return None
current = current.get(part)
if current is None:
return None
return current

126
agent/skills/formatter.py Normal file
View File

@@ -0,0 +1,126 @@
"""
Skill formatter for generating prompts from skills.
"""
from typing import Dict, List
from agent.skills.types import Skill, SkillEntry
def format_skills_for_prompt(skills: List[Skill]) -> str:
"""
Format skills for inclusion in a system prompt.
Uses XML format per Agent Skills standard.
Skills with disable_model_invocation=True are excluded.
:param skills: List of skills to format
:return: Formatted prompt text
"""
# Filter out skills that should not be invoked by the model
visible_skills = [s for s in skills if not s.disable_model_invocation]
if not visible_skills:
return ""
lines = [
"",
"<available_skills>",
]
for skill in visible_skills:
lines.append(" <skill>")
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
lines.append(f" <location>{_escape_xml(skill.file_path)}</location>")
lines.append(f" <base_dir>{_escape_xml(skill.base_dir)}</base_dir>")
lines.append(" </skill>")
lines.append("</available_skills>")
return "\n".join(lines)
def format_skill_entries_for_prompt(entries: List[SkillEntry]) -> str:
"""
Format skill entries for inclusion in a system prompt.
:param entries: List of skill entries to format
:return: Formatted prompt text
"""
skills = [entry.skill for entry in entries]
return format_skills_for_prompt(skills)
def format_unavailable_skills_for_prompt(
entries: List[SkillEntry],
missing_map: Dict[str, Dict[str, List[str]]],
) -> str:
"""
Format unavailable (requires-not-met) skills as brief setup hints
so the AI can guide users to configure them.
:param entries: List of unavailable skill entries
:param missing_map: Dict mapping skill name to its missing requirements
:return: Formatted prompt text
"""
if not entries:
return ""
lines = [
"",
"<unavailable_skills>",
"The following skills are installed but not yet ready. "
"Guide the user to complete the setup when relevant.",
]
for entry in entries:
skill = entry.skill
missing = missing_map.get(skill.name, {})
missing_parts = []
for key, values in missing.items():
missing_parts.append(f"{key}: {', '.join(values)}")
missing_str = "; ".join(missing_parts) if missing_parts else "unknown"
setup_hint = _extract_setup_hint(skill)
lines.append(" <skill>")
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
lines.append(f" <missing>{_escape_xml(missing_str)}</missing>")
if setup_hint:
lines.append(f" <setup>{_escape_xml(setup_hint)}</setup>")
lines.append(" </skill>")
lines.append("</unavailable_skills>")
return "\n".join(lines)
def _extract_setup_hint(skill: Skill) -> str:
"""
Extract the Setup section from SKILL.md content as a brief hint.
Returns the first few lines of the ## Setup section.
"""
content = skill.content
if not content:
return ""
import re
match = re.search(r'^##\s+Setup\s*\n(.*?)(?=\n##\s|\Z)', content, re.MULTILINE | re.DOTALL)
if not match:
return ""
setup_text = match.group(1).strip()
lines = setup_text.split('\n')
hint_lines = [l.strip() for l in lines[:6] if l.strip()]
return ' '.join(hint_lines)[:300]
def _escape_xml(text: str) -> str:
"""Escape XML special characters."""
return (text
.replace('&', '&amp;')
.replace('<', '&lt;')
.replace('>', '&gt;')
.replace('"', '&quot;')
.replace("'", '&apos;'))

192
agent/skills/frontmatter.py Normal file
View File

@@ -0,0 +1,192 @@
"""
Frontmatter parsing for skills.
"""
import re
import json
from typing import Dict, Any, Optional, List
from agent.skills.types import SkillMetadata, SkillInstallSpec
def parse_frontmatter(content: str) -> Dict[str, Any]:
"""
Parse YAML-style frontmatter from markdown content.
Returns a dictionary of frontmatter fields.
"""
frontmatter = {}
# Match frontmatter block between --- markers
match = re.match(r'^---\s*\n(.*?)\n---\s*\n', content, re.DOTALL)
if not match:
return frontmatter
frontmatter_text = match.group(1)
# Try to use PyYAML for proper YAML parsing
try:
import yaml
frontmatter = yaml.safe_load(frontmatter_text)
if not isinstance(frontmatter, dict):
frontmatter = {}
return frontmatter
except ImportError:
# Fallback to simple parsing if PyYAML not available
pass
except Exception:
# If YAML parsing fails, fall back to simple parsing
pass
# Simple YAML-like parsing (supports key: value format only)
# This is a fallback for when PyYAML is not available
for line in frontmatter_text.split('\n'):
line = line.strip()
if not line or line.startswith('#'):
continue
if ':' in line:
key, value = line.split(':', 1)
key = key.strip()
value = value.strip()
# Try to parse as JSON if it looks like JSON
if value.startswith('{') or value.startswith('['):
try:
value = json.loads(value)
except json.JSONDecodeError:
pass
# Parse boolean values
elif value.lower() in ('true', 'false'):
value = value.lower() == 'true'
# Parse numbers
elif value.isdigit():
value = int(value)
frontmatter[key] = value
return frontmatter
def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
"""
Parse skill metadata from frontmatter.
Looks for 'metadata' field containing JSON with skill configuration.
"""
metadata_raw = frontmatter.get('metadata')
if not metadata_raw:
return None
# If it's a string, try to parse as JSON
if isinstance(metadata_raw, str):
try:
metadata_raw = json.loads(metadata_raw)
except json.JSONDecodeError:
return None
if not isinstance(metadata_raw, dict):
return None
# Unwrap nested namespace (e.g. {"openclaw": {...}} or {"cowagent": {...}})
meta_obj = _unwrap_metadata_namespace(metadata_raw)
# Parse install specs
install_specs = []
install_raw = meta_obj.get('install', [])
if isinstance(install_raw, list):
for spec_raw in install_raw:
if not isinstance(spec_raw, dict):
continue
kind = spec_raw.get('kind', spec_raw.get('type', '')).lower()
if not kind:
continue
spec = SkillInstallSpec(
kind=kind,
id=spec_raw.get('id'),
label=spec_raw.get('label'),
bins=_normalize_string_list(spec_raw.get('bins')),
os=_normalize_string_list(spec_raw.get('os')),
formula=spec_raw.get('formula'),
package=spec_raw.get('package'),
module=spec_raw.get('module'),
url=spec_raw.get('url'),
archive=spec_raw.get('archive'),
extract=spec_raw.get('extract', False),
strip_components=spec_raw.get('stripComponents'),
target_dir=spec_raw.get('targetDir'),
)
install_specs.append(spec)
# Parse requires
requires = {}
requires_raw = meta_obj.get('requires', {})
if isinstance(requires_raw, dict):
for key, value in requires_raw.items():
requires[key] = _normalize_string_list(value)
return SkillMetadata(
always=meta_obj.get('always', False),
default_enabled=meta_obj.get('default_enabled', True),
skill_key=meta_obj.get('skillKey'),
primary_env=meta_obj.get('primaryEnv'),
emoji=meta_obj.get('emoji'),
homepage=meta_obj.get('homepage'),
os=_normalize_string_list(meta_obj.get('os')),
requires=requires,
install=install_specs,
)
_KNOWN_METADATA_NAMESPACES = {"cowagent", "openclaw"}
def _unwrap_metadata_namespace(metadata_raw: Dict[str, Any]) -> Dict[str, Any]:
"""
Unwrap a single-key namespace wrapper like {"cowagent": {...} or {"openclaw": {...}}}.
If the top-level dict has exactly one key matching a known namespace, return the inner dict.
Otherwise return the original dict unchanged.
"""
keys = set(metadata_raw.keys())
ns_keys = keys & _KNOWN_METADATA_NAMESPACES
if len(ns_keys) == 1 and len(keys) == 1:
ns = ns_keys.pop()
inner = metadata_raw[ns]
if isinstance(inner, dict):
return inner
return metadata_raw
def _normalize_string_list(value: Any) -> List[str]:
"""Normalize a value to a list of strings."""
if not value:
return []
if isinstance(value, list):
return [str(v).strip() for v in value if v]
if isinstance(value, str):
return [v.strip() for v in value.split(',') if v.strip()]
return []
def parse_boolean_value(value: Optional[str], default: bool = False) -> bool:
"""Parse a boolean value from frontmatter."""
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in ('true', '1', 'yes', 'on')
return default
def get_frontmatter_value(frontmatter: Dict[str, Any], key: str) -> Optional[str]:
"""Get a frontmatter value as a string."""
value = frontmatter.get(key)
return str(value) if value is not None else None

277
agent/skills/loader.py Normal file
View File

@@ -0,0 +1,277 @@
"""
Skill loader for discovering and loading skills from directories.
"""
import os
from pathlib import Path
from typing import List, Optional, Dict
from common.log import logger
from agent.skills.types import Skill, SkillEntry, LoadSkillsResult, SkillMetadata
from agent.skills.frontmatter import parse_frontmatter, parse_metadata, parse_boolean_value, get_frontmatter_value
class SkillLoader:
"""Loads skills from various directories."""
def __init__(self):
pass
def load_skills_from_dir(self, dir_path: str, source: str) -> LoadSkillsResult:
"""
Load skills from a directory.
Discovery rules:
- Direct .md files in the root directory
- Recursive SKILL.md files under subdirectories
:param dir_path: Directory path to scan
:param source: Source identifier ('builtin' or 'custom')
:return: LoadSkillsResult with skills and diagnostics
"""
skills = []
diagnostics = []
if not os.path.exists(dir_path):
diagnostics.append(f"Directory does not exist: {dir_path}")
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
if not os.path.isdir(dir_path):
diagnostics.append(f"Path is not a directory: {dir_path}")
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
# Load skills from root-level .md files and subdirectories
result = self._load_skills_recursive(dir_path, source, include_root_files=True)
return result
def _load_skills_recursive(
self,
dir_path: str,
source: str,
include_root_files: bool = False
) -> LoadSkillsResult:
"""
Recursively load skills from a directory.
:param dir_path: Directory to scan
:param source: Source identifier
:param include_root_files: Whether to include root-level .md files
:return: LoadSkillsResult
"""
skills = []
diagnostics = []
try:
entries = os.listdir(dir_path)
except Exception as e:
diagnostics.append(f"Failed to list directory {dir_path}: {e}")
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
for entry in entries:
# Skip hidden files and directories
if entry.startswith('.'):
continue
# Skip common non-skill directories
if entry in ('node_modules', '__pycache__', 'venv', '.git'):
continue
full_path = os.path.join(dir_path, entry)
# Handle directories
if os.path.isdir(full_path):
# Recursively scan subdirectories
sub_result = self._load_skills_recursive(full_path, source, include_root_files=False)
skills.extend(sub_result.skills)
diagnostics.extend(sub_result.diagnostics)
continue
# Handle files
if not os.path.isfile(full_path):
continue
# Check if this is a skill file
is_root_md = include_root_files and entry.endswith('.md') and entry.upper() != 'README.MD'
is_skill_md = not include_root_files and entry == 'SKILL.md'
if not (is_root_md or is_skill_md):
continue
# Load the skill
skill_result = self._load_skill_from_file(full_path, source)
if skill_result.skills:
skills.extend(skill_result.skills)
diagnostics.extend(skill_result.diagnostics)
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
def _load_skill_from_file(self, file_path: str, source: str) -> LoadSkillsResult:
"""
Load a single skill from a markdown file.
:param file_path: Path to the skill markdown file
:param source: Source identifier
:return: LoadSkillsResult
"""
diagnostics = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
except Exception as e:
diagnostics.append(f"Failed to read skill file {file_path}: {e}")
return LoadSkillsResult(skills=[], diagnostics=diagnostics)
# Parse frontmatter
frontmatter = parse_frontmatter(content)
# Get skill name and description
skill_dir = os.path.dirname(file_path)
parent_dir_name = os.path.basename(skill_dir)
name = frontmatter.get('name', parent_dir_name)
description = frontmatter.get('description', '')
# Normalize name (handle both string and list)
if isinstance(name, list):
name = name[0] if name else parent_dir_name
elif not isinstance(name, str):
name = str(name) if name else parent_dir_name
# Normalize description (handle both string and list)
if isinstance(description, list):
description = ' '.join(str(d) for d in description if d)
elif not isinstance(description, str):
description = str(description) if description else ''
# Special handling for linkai-agent: dynamically load apps from config.json
if name == 'linkai-agent':
description = self._load_linkai_agent_description(skill_dir, description)
if not description or not description.strip():
diagnostics.append(f"Skill {name} has no description: {file_path}")
return LoadSkillsResult(skills=[], diagnostics=diagnostics)
# Parse disable-model-invocation flag
disable_model_invocation = parse_boolean_value(
get_frontmatter_value(frontmatter, 'disable-model-invocation'),
default=False
)
# Create skill object
skill = Skill(
name=name,
description=description,
file_path=file_path,
base_dir=skill_dir,
source=source,
content=content,
disable_model_invocation=disable_model_invocation,
frontmatter=frontmatter,
)
return LoadSkillsResult(skills=[skill], diagnostics=diagnostics)
def _load_linkai_agent_description(self, skill_dir: str, default_description: str) -> str:
"""
Dynamically load LinkAI agent description from config.json
:param skill_dir: Skill directory
:param default_description: Default description from SKILL.md
:return: Dynamic description with app list
"""
import json
config_path = os.path.join(skill_dir, "config.json")
if not os.path.exists(config_path):
logger.debug(f"[SkillLoader] linkai-agent skipped: no config.json found")
return ""
try:
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
apps = config.get("apps", [])
if not apps:
return default_description
# Build dynamic description with app details
app_descriptions = "; ".join([
f"{app['app_name']}({app['app_code']}: {app['app_description']})"
for app in apps
])
return f"Call LinkAI apps/workflows. {app_descriptions}"
except Exception as e:
logger.warning(f"[SkillLoader] Failed to load linkai-agent config: {e}")
return default_description
def load_all_skills(
self,
builtin_dir: Optional[str] = None,
custom_dir: Optional[str] = None,
) -> Dict[str, SkillEntry]:
"""
Load skills from builtin and custom directories.
Precedence (lowest to highest):
1. builtin — project root ``skills/``, shipped with the codebase
2. custom — workspace ``skills/``, installed via cloud console or skill creator
Same-name custom skills override builtin ones.
:param builtin_dir: Built-in skills directory
:param custom_dir: Custom skills directory
:return: Dictionary mapping skill name to SkillEntry
"""
skill_map: Dict[str, SkillEntry] = {}
all_diagnostics = []
# Load builtin skills (lower precedence)
if builtin_dir and os.path.exists(builtin_dir):
result = self.load_skills_from_dir(builtin_dir, source='builtin')
all_diagnostics.extend(result.diagnostics)
for skill in result.skills:
entry = self._create_skill_entry(skill)
skill_map[skill.name] = entry
# Load custom skills (higher precedence, overrides builtin)
if custom_dir and os.path.exists(custom_dir):
result = self.load_skills_from_dir(custom_dir, source='custom')
all_diagnostics.extend(result.diagnostics)
for skill in result.skills:
entry = self._create_skill_entry(skill)
skill_map[skill.name] = entry
# Log diagnostics
if all_diagnostics:
logger.debug(f"Skill loading diagnostics: {len(all_diagnostics)} issues")
for diag in all_diagnostics[:5]:
logger.debug(f" - {diag}")
logger.debug(f"Loaded {len(skill_map)} skills total")
return skill_map
def _create_skill_entry(self, skill: Skill) -> SkillEntry:
"""
Create a SkillEntry from a Skill with parsed metadata.
:param skill: The skill to create an entry for
:return: SkillEntry with metadata
"""
metadata = parse_metadata(skill.frontmatter)
# Parse user-invocable flag
user_invocable = parse_boolean_value(
get_frontmatter_value(skill.frontmatter, 'user-invocable'),
default=True
)
return SkillEntry(
skill=skill,
metadata=metadata,
user_invocable=user_invocable,
)

357
agent/skills/manager.py Normal file
View File

@@ -0,0 +1,357 @@
"""
Skill manager for managing skill lifecycle and operations.
"""
import os
import json
from typing import Dict, List, Optional
from pathlib import Path
from common.log import logger
from agent.skills.types import Skill, SkillEntry, SkillSnapshot
from agent.skills.loader import SkillLoader
from agent.skills.formatter import format_skill_entries_for_prompt
SKILLS_CONFIG_FILE = "skills_config.json"
class SkillManager:
"""Manages skills for an agent."""
def __init__(
self,
builtin_dir: Optional[str] = None,
custom_dir: Optional[str] = None,
config: Optional[Dict] = None,
):
"""
Initialize the skill manager.
:param builtin_dir: Built-in skills directory (project root ``skills/``)
:param custom_dir: Custom skills directory (workspace ``skills/``)
:param config: Configuration dictionary
"""
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
self.builtin_dir = builtin_dir or os.path.join(project_root, 'skills')
self.custom_dir = custom_dir or os.path.join(project_root, 'workspace', 'skills')
self.config = config or {}
self._skills_config_path = os.path.join(self.custom_dir, SKILLS_CONFIG_FILE)
# skills_config: full skill metadata keyed by name
# { "web-fetch": {"name": ..., "description": ..., "source": ..., "enabled": true}, ... }
self.skills_config: Dict[str, dict] = {}
self.loader = SkillLoader()
self.skills: Dict[str, SkillEntry] = {}
# Load skills on initialization
self.refresh_skills()
def refresh_skills(self):
"""Reload all skills from builtin and custom directories, then sync config."""
self.skills = self.loader.load_all_skills(
builtin_dir=self.builtin_dir,
custom_dir=self.custom_dir,
)
self._sync_skills_config()
logger.debug(f"SkillManager: Loaded {len(self.skills)} skills")
# ------------------------------------------------------------------
# skills_config.json management
# ------------------------------------------------------------------
def _load_skills_config(self) -> Dict[str, dict]:
"""Load skills_config.json from custom_dir. Returns empty dict if not found."""
if not os.path.exists(self._skills_config_path):
return {}
try:
with open(self._skills_config_path, "r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, dict):
return data
except Exception as e:
logger.warning(f"[SkillManager] Failed to load {SKILLS_CONFIG_FILE}: {e}")
return {}
def _save_skills_config(self):
"""Persist skills_config to custom_dir/skills_config.json."""
os.makedirs(self.custom_dir, exist_ok=True)
try:
with open(self._skills_config_path, "w", encoding="utf-8") as f:
json.dump(self.skills_config, f, indent=4, ensure_ascii=False)
except Exception as e:
logger.error(f"[SkillManager] Failed to save {SKILLS_CONFIG_FILE}: {e}")
def _sync_skills_config(self):
"""
Merge directory-scanned skills with the persisted config file.
- New skills: use metadata.default_enabled as initial enabled state.
- Existing skills: preserve their persisted enabled state.
- Skills that no longer exist on disk are removed.
- name/description/source are always refreshed from the latest scan.
"""
saved = self._load_skills_config()
merged: Dict[str, dict] = {}
for name, entry in self.skills.items():
skill = entry.skill
prev = saved.get(name, {})
category = prev.get("category", "skill")
if name in saved:
enabled = prev.get("enabled", True)
else:
enabled = entry.metadata.default_enabled if entry.metadata else True
entry_dict = {
"name": name,
"description": skill.description,
"source": prev.get("source") or skill.source,
"enabled": enabled,
"category": category,
}
display_name = prev.get("display_name")
if display_name:
entry_dict["display_name"] = display_name
merged[name] = entry_dict
self.skills_config = merged
self._save_skills_config()
def is_skill_enabled(self, name: str) -> bool:
"""
Check if a skill is enabled according to skills_config.
:param name: skill name
:return: True if enabled (default True if not in config)
"""
entry = self.skills_config.get(name)
if entry is None:
return True
return entry.get("enabled", True)
def set_skill_enabled(self, name: str, enabled: bool):
"""
Set a skill's enabled state and persist.
:param name: skill name
:param enabled: True to enable, False to disable
"""
if name not in self.skills_config:
raise ValueError(f"skill '{name}' not found in config")
self.skills_config[name]["enabled"] = enabled
self._save_skills_config()
def get_skills_config(self) -> Dict[str, dict]:
"""
Return the full skills_config dict (for query API).
:return: copy of skills_config
"""
return dict(self.skills_config)
def get_skill(self, name: str) -> Optional[SkillEntry]:
"""
Get a skill by name.
:param name: Skill name
:return: SkillEntry or None if not found
"""
return self.skills.get(name)
def list_skills(self) -> List[SkillEntry]:
"""
Get all loaded skills.
:return: List of all skill entries
"""
return list(self.skills.values())
@staticmethod
def _normalize_skill_filter(skill_filter: Optional[List[str]]) -> Optional[List[str]]:
"""Normalize a skill_filter list into a flat list of stripped names."""
if skill_filter is None:
return None
normalized = []
for item in skill_filter:
if isinstance(item, str):
name = item.strip()
if name:
normalized.append(name)
elif isinstance(item, list):
for subitem in item:
if isinstance(subitem, str):
name = subitem.strip()
if name:
normalized.append(name)
return normalized or None
def filter_skills(
self,
skill_filter: Optional[List[str]] = None,
include_disabled: bool = False,
) -> List[SkillEntry]:
"""
Filter skills that are eligible (enabled + requirements met).
:param skill_filter: List of skill names to include (None = all)
:param include_disabled: Whether to include disabled skills
:return: Filtered list of eligible skill entries
"""
from agent.skills.config import should_include_skill
entries = list(self.skills.values())
entries = [e for e in entries if should_include_skill(e, self.config)]
normalized = self._normalize_skill_filter(skill_filter)
if normalized is not None:
entries = [e for e in entries if e.skill.name in normalized]
if not include_disabled:
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
return entries
def filter_unavailable_skills(
self,
skill_filter: Optional[List[str]] = None,
) -> tuple:
"""
Find skills that are enabled but have unmet requirements.
:param skill_filter: Optional list of skill names to include
:return: Tuple of (entries, missing_map) where missing_map maps
skill name to its missing requirements dict
"""
from agent.skills.config import should_include_skill, get_missing_requirements
entries = list(self.skills.values())
# Only enabled skills
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
normalized = self._normalize_skill_filter(skill_filter)
if normalized is not None:
entries = [e for e in entries if e.skill.name in normalized]
# Keep only those that fail should_include_skill (requirements not met)
unavailable = []
missing_map: Dict[str, dict] = {}
for e in entries:
if not should_include_skill(e, self.config):
missing = get_missing_requirements(e)
if missing:
unavailable.append(e)
missing_map[e.skill.name] = missing
return unavailable, missing_map
def build_skills_prompt(
self,
skill_filter: Optional[List[str]] = None,
) -> str:
"""
Build a formatted prompt containing available skills
and brief hints for unavailable ones.
:param skill_filter: Optional list of skill names to include
:return: Formatted skills prompt
"""
from common.log import logger
from agent.skills.formatter import format_unavailable_skills_for_prompt
eligible = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
logger.debug(f"[SkillManager] Eligible: {len(eligible)} skills (total: {len(self.skills)})")
if eligible:
skill_names = [e.skill.name for e in eligible]
logger.debug(f"[SkillManager] Eligible skills: {skill_names}")
result = format_skill_entries_for_prompt(eligible)
unavailable, missing_map = self.filter_unavailable_skills(skill_filter=skill_filter)
if unavailable:
unavailable_names = [e.skill.name for e in unavailable]
logger.debug(f"[SkillManager] Unavailable skills (setup needed): {unavailable_names}")
result += format_unavailable_skills_for_prompt(unavailable, missing_map)
logger.debug(f"[SkillManager] Generated prompt length: {len(result)}")
return result
def build_skill_snapshot(
self,
skill_filter: Optional[List[str]] = None,
version: Optional[int] = None,
) -> SkillSnapshot:
"""
Build a snapshot of skills for a specific run.
:param skill_filter: Optional list of skill names to include
:param version: Optional version number for the snapshot
:return: SkillSnapshot
"""
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
prompt = format_skill_entries_for_prompt(entries)
skills_info = []
resolved_skills = []
for entry in entries:
skills_info.append({
'name': entry.skill.name,
'primary_env': entry.metadata.primary_env if entry.metadata else None,
})
resolved_skills.append(entry.skill)
return SkillSnapshot(
prompt=prompt,
skills=skills_info,
resolved_skills=resolved_skills,
version=version,
)
def sync_skills_to_workspace(self, target_workspace_dir: str):
"""
Sync all loaded skills to a target workspace directory.
This is useful for sandbox environments where skills need to be copied.
:param target_workspace_dir: Target workspace directory
"""
import shutil
target_skills_dir = os.path.join(target_workspace_dir, 'skills')
# Remove existing skills directory
if os.path.exists(target_skills_dir):
shutil.rmtree(target_skills_dir)
# Create new skills directory
os.makedirs(target_skills_dir, exist_ok=True)
# Copy each skill
for entry in self.skills.values():
skill_name = entry.skill.name
source_dir = entry.skill.base_dir
target_dir = os.path.join(target_skills_dir, skill_name)
try:
shutil.copytree(source_dir, target_dir)
logger.debug(f"Synced skill '{skill_name}' to {target_dir}")
except Exception as e:
logger.warning(f"Failed to sync skill '{skill_name}': {e}")
logger.info(f"Synced {len(self.skills)} skills to {target_skills_dir}")
def get_skill_by_key(self, skill_key: str) -> Optional[SkillEntry]:
"""
Get a skill by its skill key (which may differ from name).
:param skill_key: Skill key to look up
:return: SkillEntry or None
"""
for entry in self.skills.values():
if entry.metadata and entry.metadata.skill_key == skill_key:
return entry
if entry.skill.name == skill_key:
return entry
return None

285
agent/skills/service.py Normal file
View File

@@ -0,0 +1,285 @@
"""
Skill service for handling skill CRUD operations.
This service provides a unified interface for managing skills, which can be
called from the cloud control client (LinkAI), the local web console, or any
other management entry point.
"""
import os
import shutil
import zipfile
import tempfile
from typing import Dict, List, Optional
from common.log import logger
from agent.skills.types import Skill, SkillEntry
from agent.skills.manager import SkillManager
try:
import requests
except ImportError:
requests = None
class SkillService:
"""
High-level service for skill lifecycle management.
Wraps SkillManager and provides network-aware operations such as
downloading skill files from remote URLs.
"""
def __init__(self, skill_manager: SkillManager):
"""
:param skill_manager: The SkillManager instance to operate on
"""
self.manager = skill_manager
# ------------------------------------------------------------------
# query
# ------------------------------------------------------------------
def query(self) -> List[dict]:
"""
Query all skills and return a serialisable list.
Reads from skills_config.json (refreshes from disk if needed).
:return: list of skill info dicts
"""
self.manager.refresh_skills()
config = self.manager.get_skills_config()
result = list(config.values())
logger.info(f"[SkillService] query: {len(result)} skills found")
return result
# ------------------------------------------------------------------
# add / install
# ------------------------------------------------------------------
def add(self, payload: dict) -> None:
"""
Add (install) a skill from a remote payload.
Supported payload types:
1. ``type: "url"`` download individual files::
{
"name": "web_search",
"type": "url",
"enabled": true,
"files": [
{"url": "https://...", "path": "README.md"},
{"url": "https://...", "path": "scripts/main.py"}
]
}
2. ``type: "package"`` download a zip archive and extract::
{
"name": "plugin-custom-tool",
"type": "package",
"category": "skills",
"enabled": true,
"files": [{"url": "https://cdn.example.com/skills/custom-tool.zip"}]
}
:param payload: skill add payload from server
"""
name = payload.get("name")
if not name:
raise ValueError("skill name is required")
payload_type = payload.get("type", "url")
if payload_type == "package":
self._add_package(name, payload)
else:
self._add_url(name, payload)
self.manager.refresh_skills()
category = payload.get("category")
if category and name in self.manager.skills_config:
self.manager.skills_config[name]["category"] = category
self.manager._save_skills_config()
def _add_url(self, name: str, payload: dict) -> None:
"""Install a skill by downloading individual files."""
files = payload.get("files", [])
if not files:
raise ValueError("skill files list is empty")
skill_dir = os.path.join(self.manager.custom_dir, name)
tmp_dir = skill_dir + ".tmp"
if os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)
os.makedirs(tmp_dir, exist_ok=True)
try:
for file_info in files:
url = file_info.get("url")
rel_path = file_info.get("path")
if not url or not rel_path:
logger.warning(f"[SkillService] add: skip invalid file entry {file_info}")
continue
dest = os.path.join(tmp_dir, rel_path)
self._download_file(url, dest)
except Exception:
shutil.rmtree(tmp_dir, ignore_errors=True)
raise
if os.path.exists(skill_dir):
shutil.rmtree(skill_dir)
os.rename(tmp_dir, skill_dir)
logger.info(f"[SkillService] add: skill '{name}' installed via url ({len(files)} files)")
def _add_package(self, name: str, payload: dict) -> None:
"""
Install a skill by downloading a zip archive and extracting it.
If the archive contains a single top-level directory, that directory
is used as the skill folder directly; otherwise a new directory named
after the skill is created to hold the extracted contents.
"""
files = payload.get("files", [])
if not files or not files[0].get("url"):
raise ValueError("package url is required")
url = files[0]["url"]
skill_dir = os.path.join(self.manager.custom_dir, name)
with tempfile.TemporaryDirectory() as tmp_dir:
zip_path = os.path.join(tmp_dir, "package.zip")
self._download_file(url, zip_path)
if not zipfile.is_zipfile(zip_path):
raise ValueError(f"downloaded file is not a valid zip archive: {url}")
extract_dir = os.path.join(tmp_dir, "extracted")
with zipfile.ZipFile(zip_path, "r") as zf:
zf.extractall(extract_dir)
# Determine the actual content root.
# If the zip has a single top-level directory, use its contents
# so the skill folder is clean (no extra nesting).
top_items = [
item for item in os.listdir(extract_dir)
if not item.startswith(".")
]
if len(top_items) == 1:
single = os.path.join(extract_dir, top_items[0])
if os.path.isdir(single):
extract_dir = single
if os.path.exists(skill_dir):
shutil.rmtree(skill_dir)
shutil.copytree(extract_dir, skill_dir)
logger.info(f"[SkillService] add: skill '{name}' installed via package ({url})")
# ------------------------------------------------------------------
# open / close (enable / disable)
# ------------------------------------------------------------------
def open(self, payload: dict) -> None:
"""
Enable a skill by name.
:param payload: {"name": "skill_name"}
"""
name = payload.get("name")
if not name:
raise ValueError("skill name is required")
self.manager.set_skill_enabled(name, enabled=True)
logger.info(f"[SkillService] open: skill '{name}' enabled")
def close(self, payload: dict) -> None:
"""
Disable a skill by name.
:param payload: {"name": "skill_name"}
"""
name = payload.get("name")
if not name:
raise ValueError("skill name is required")
self.manager.set_skill_enabled(name, enabled=False)
logger.info(f"[SkillService] close: skill '{name}' disabled")
# ------------------------------------------------------------------
# delete
# ------------------------------------------------------------------
def delete(self, payload: dict) -> None:
"""
Delete a skill by removing its directory entirely.
:param payload: {"name": "skill_name"}
"""
name = payload.get("name")
if not name:
raise ValueError("skill name is required")
skill_dir = os.path.join(self.manager.custom_dir, name)
if os.path.exists(skill_dir):
shutil.rmtree(skill_dir)
logger.info(f"[SkillService] delete: removed directory {skill_dir}")
else:
logger.warning(f"[SkillService] delete: skill directory not found: {skill_dir}")
# Refresh will remove the deleted skill from config automatically
self.manager.refresh_skills()
logger.info(f"[SkillService] delete: skill '{name}' deleted")
# ------------------------------------------------------------------
# dispatch - single entry point for protocol messages
# ------------------------------------------------------------------
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
"""
Dispatch a skill management action and return a protocol-compatible
response dict.
:param action: one of query / add / open / close / delete
:param payload: action-specific payload (may be None for query)
:return: dict with action, code, message, payload
"""
payload = payload or {}
try:
if action == "query":
result_payload = self.query()
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
elif action == "add":
self.add(payload)
elif action == "open":
self.open(payload)
elif action == "close":
self.close(payload)
elif action == "delete":
self.delete(payload)
else:
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
return {"action": action, "code": 200, "message": "success", "payload": None}
except Exception as e:
logger.error(f"[SkillService] dispatch error: action={action}, error={e}")
return {"action": action, "code": 500, "message": str(e), "payload": None}
# ------------------------------------------------------------------
# internal helpers
# ------------------------------------------------------------------
@staticmethod
def _download_file(url: str, dest: str):
"""
Download a file from *url* and save to *dest*.
:param url: remote file URL
:param dest: local destination path
"""
if requests is None:
raise RuntimeError("requests library is required for downloading skill files")
dest_dir = os.path.dirname(dest)
if dest_dir:
os.makedirs(dest_dir, exist_ok=True)
resp = requests.get(url, timeout=60)
resp.raise_for_status()
with open(dest, "wb") as f:
f.write(resp.content)
logger.debug(f"[SkillService] downloaded {url} -> {dest}")

76
agent/skills/types.py Normal file
View File

@@ -0,0 +1,76 @@
"""
Type definitions for skills system.
"""
from __future__ import annotations
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
@dataclass
class SkillInstallSpec:
"""Specification for installing skill dependencies."""
kind: str # brew, pip, npm, download, etc.
id: Optional[str] = None
label: Optional[str] = None
bins: List[str] = field(default_factory=list)
os: List[str] = field(default_factory=list)
formula: Optional[str] = None # for brew
package: Optional[str] = None # for pip/npm
module: Optional[str] = None
url: Optional[str] = None # for download
archive: Optional[str] = None
extract: bool = False
strip_components: Optional[int] = None
target_dir: Optional[str] = None
@dataclass
class SkillMetadata:
"""Metadata for a skill from frontmatter."""
always: bool = False # Always include this skill
default_enabled: bool = True # Initial enabled state when first discovered
skill_key: Optional[str] = None # Override skill key
primary_env: Optional[str] = None # Primary environment variable
emoji: Optional[str] = None
homepage: Optional[str] = None
os: List[str] = field(default_factory=list) # Supported OS platforms
requires: Dict[str, List[str]] = field(default_factory=dict) # Requirements
install: List[SkillInstallSpec] = field(default_factory=list)
@dataclass
class Skill:
"""Represents a skill loaded from a markdown file."""
name: str
description: str
file_path: str
base_dir: str
source: str # builtin or custom
content: str # Full markdown content
disable_model_invocation: bool = False
frontmatter: Dict[str, Any] = field(default_factory=dict)
@dataclass
class SkillEntry:
"""A skill with parsed metadata."""
skill: Skill
metadata: Optional[SkillMetadata] = None
user_invocable: bool = True # Can users invoke this skill directly
@dataclass
class LoadSkillsResult:
"""Result of loading skills from a directory."""
skills: List[Skill]
diagnostics: List[str] = field(default_factory=list)
@dataclass
class SkillSnapshot:
"""Snapshot of skills for a specific run."""
prompt: str # Formatted prompt text
skills: List[Dict[str, str]] # List of skill info (name, primary_env)
resolved_skills: List[Skill] = field(default_factory=list)
version: Optional[int] = None

132
agent/tools/__init__.py Normal file
View File

@@ -0,0 +1,132 @@
# Import base tool
from agent.tools.base_tool import BaseTool
from agent.tools.tool_manager import ToolManager
# Import file operation tools
from agent.tools.read.read import Read
from agent.tools.write.write import Write
from agent.tools.edit.edit import Edit
from agent.tools.bash.bash import Bash
from agent.tools.ls.ls import Ls
from agent.tools.send.send import Send
# Import memory tools
from agent.tools.memory.memory_search import MemorySearchTool
from agent.tools.memory.memory_get import MemoryGetTool
# Import tools with optional dependencies
def _import_optional_tools():
"""Import tools that have optional dependencies"""
from common.log import logger
tools = {}
# EnvConfig Tool (requires python-dotenv)
try:
from agent.tools.env_config.env_config import EnvConfig
tools['EnvConfig'] = EnvConfig
except ImportError as e:
logger.error(
f"[Tools] EnvConfig tool not loaded - missing dependency: {e}\n"
f" To enable environment variable management, run:\n"
f" pip install python-dotenv>=1.0.0"
)
except Exception as e:
logger.error(f"[Tools] EnvConfig tool failed to load: {e}")
# Scheduler Tool (requires croniter)
try:
from agent.tools.scheduler.scheduler_tool import SchedulerTool
tools['SchedulerTool'] = SchedulerTool
except ImportError as e:
logger.error(
f"[Tools] Scheduler tool not loaded - missing dependency: {e}\n"
f" To enable scheduled tasks, run:\n"
f" pip install croniter>=2.0.0"
)
except Exception as e:
logger.error(f"[Tools] Scheduler tool failed to load: {e}")
# WebSearch Tool (conditionally loaded based on API key availability at init time)
try:
from agent.tools.web_search.web_search import WebSearch
tools['WebSearch'] = WebSearch
except ImportError as e:
logger.error(f"[Tools] WebSearch not loaded - missing dependency: {e}")
except Exception as e:
logger.error(f"[Tools] WebSearch failed to load: {e}")
# WebFetch Tool
try:
from agent.tools.web_fetch.web_fetch import WebFetch
tools['WebFetch'] = WebFetch
except ImportError as e:
logger.error(f"[Tools] WebFetch not loaded - missing dependency: {e}")
except Exception as e:
logger.error(f"[Tools] WebFetch failed to load: {e}")
# Vision Tool (conditionally loaded based on API key availability)
try:
from agent.tools.vision.vision import Vision
tools['Vision'] = Vision
except ImportError as e:
logger.error(f"[Tools] Vision not loaded - missing dependency: {e}")
except Exception as e:
logger.error(f"[Tools] Vision failed to load: {e}")
return tools
# Load optional tools
_optional_tools = _import_optional_tools()
EnvConfig = _optional_tools.get('EnvConfig')
SchedulerTool = _optional_tools.get('SchedulerTool')
WebSearch = _optional_tools.get('WebSearch')
WebFetch = _optional_tools.get('WebFetch')
Vision = _optional_tools.get('Vision')
GoogleSearch = _optional_tools.get('GoogleSearch')
FileSave = _optional_tools.get('FileSave')
Terminal = _optional_tools.get('Terminal')
# BrowserTool (requires playwright)
def _import_browser_tool():
from common.log import logger
try:
from agent.tools.browser.browser_tool import BrowserTool
return BrowserTool
except ImportError as e:
logger.info(
f"[Tools] BrowserTool not loaded - missing dependency: {e}\n"
f" To enable browser tool, run:\n"
f" pip install playwright\n"
f" playwright install chromium"
)
return None
except Exception as e:
logger.error(f"[Tools] BrowserTool failed to load: {e}")
return None
BrowserTool = _import_browser_tool()
# Export all tools (including optional ones that might be None)
__all__ = [
'BaseTool',
'ToolManager',
'Read',
'Write',
'Edit',
'Bash',
'Ls',
'Send',
'MemorySearchTool',
'MemoryGetTool',
'EnvConfig',
'SchedulerTool',
'WebSearch',
'WebFetch',
'Vision',
'BrowserTool',
]
"""
Tools module for Agent.
"""

99
agent/tools/base_tool.py Normal file
View File

@@ -0,0 +1,99 @@
from enum import Enum
from typing import Any, Optional
from common.log import logger
import copy
class ToolStage(Enum):
"""Enum representing tool decision stages"""
PRE_PROCESS = "pre_process" # Tools that need to be actively selected by the agent
POST_PROCESS = "post_process" # Tools that automatically execute after final_answer
class ToolResult:
"""Tool execution result"""
def __init__(self, status: str = None, result: Any = None, ext_data: Any = None):
self.status = status
self.result = result
self.ext_data = ext_data
@staticmethod
def success(result, ext_data: Any = None):
return ToolResult(status="success", result=result, ext_data=ext_data)
@staticmethod
def fail(result, ext_data: Any = None):
return ToolResult(status="error", result=result, ext_data=ext_data)
class BaseTool:
"""Base class for all tools."""
# Default decision stage is pre-process
stage = ToolStage.PRE_PROCESS
# Class attributes must be inherited
name: str = "base_tool"
description: str = "Base tool"
params: dict = {} # Store JSON Schema
model: Optional[Any] = None # LLM model instance, type depends on bot implementation
@classmethod
def get_json_schema(cls) -> dict:
"""Get the standard description of the tool"""
return {
"name": cls.name,
"description": cls.description,
"parameters": cls.params
}
def execute_tool(self, params: dict) -> ToolResult:
try:
return self.execute(params)
except Exception as e:
logger.error(e)
def execute(self, params: dict) -> ToolResult:
"""Specific logic to be implemented by subclasses"""
raise NotImplementedError
@classmethod
def _parse_schema(cls) -> dict:
"""Convert JSON Schema to Pydantic fields"""
fields = {}
for name, prop in cls.params["properties"].items():
# Convert JSON Schema types to Python types
type_map = {
"string": str,
"number": float,
"integer": int,
"boolean": bool,
"array": list,
"object": dict
}
fields[name] = (
type_map[prop["type"]],
prop.get("default", ...)
)
return fields
def should_auto_execute(self, context) -> bool:
"""
Determine if this tool should be automatically executed based on context.
:param context: The agent context
:return: True if the tool should be executed, False otherwise
"""
# Only tools in post-process stage will be automatically executed
return self.stage == ToolStage.POST_PROCESS
def close(self):
"""
Close any resources used by the tool.
This method should be overridden by tools that need to clean up resources
such as browser connections, file handles, etc.
By default, this method does nothing.
"""
pass

View File

@@ -0,0 +1,3 @@
from .bash import Bash
__all__ = ['Bash']

291
agent/tools/bash/bash.py Normal file
View File

@@ -0,0 +1,291 @@
"""
Bash tool - Execute bash commands
"""
import os
import re
import sys
import subprocess
import tempfile
from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
from common.log import logger
from common.utils import expand_path
class Bash(BaseTool):
"""Tool for executing bash commands"""
name: str = "bash"
description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.
ENVIRONMENT: All API keys from env_config are auto-injected. Use $VAR_NAME directly.
SAFETY:
- Freely create/modify/delete files within the workspace
- For destructive and out-of-workspace commands, explain and confirm first"""
params: dict = {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "Bash command to execute"
},
"timeout": {
"type": "integer",
"description": "Timeout in seconds (optional, default: 30)"
}
},
"required": ["command"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
# Ensure working directory exists
if not os.path.exists(self.cwd):
os.makedirs(self.cwd, exist_ok=True)
self.default_timeout = self.config.get("timeout", 30)
# Enable safety mode by default (can be disabled in config)
self.safety_mode = self.config.get("safety_mode", True)
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute a bash command
:param args: Dictionary containing the command and optional timeout
:return: Command output or error
"""
command = args.get("command", "").strip()
timeout = args.get("timeout", self.default_timeout)
if not command:
return ToolResult.fail("Error: command parameter is required")
# Security check: Prevent accessing sensitive config files
if "~/.cow/.env" in command or "~/.cow" in command:
return ToolResult.fail(
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
)
# Optional safety check - only warn about extremely dangerous commands
if self.safety_mode:
warning = self._get_safety_warning(command)
if warning:
return ToolResult.fail(
f"Safety Warning: {warning}\n\nIf you believe this command is safe and necessary, please ask the user for confirmation first, explaining what the command does and why it's needed.")
try:
# Prepare environment with .env file variables
env = os.environ.copy()
# Load environment variables from ~/.cow/.env if it exists
env_file = expand_path("~/.cow/.env")
dotenv_vars = {}
if os.path.exists(env_file):
try:
from dotenv import dotenv_values
dotenv_vars = dotenv_values(env_file)
env.update(dotenv_vars)
logger.debug(f"[Bash] Loaded {len(dotenv_vars)} variables from {env_file}")
except ImportError:
logger.debug("[Bash] python-dotenv not installed, skipping .env loading")
except Exception as e:
logger.debug(f"[Bash] Failed to load .env: {e}")
# getuid() only exists on Unix-like systems
if hasattr(os, 'getuid'):
logger.debug(f"[Bash] Process UID: {os.getuid()}")
else:
logger.debug(f"[Bash] Process User: {os.environ.get('USERNAME', os.environ.get('USER', 'unknown'))}")
# On Windows, convert $VAR references to %VAR% for cmd.exe
if sys.platform == "win32":
env["PYTHONIOENCODING"] = "utf-8"
command = self._convert_env_vars_for_windows(command, dotenv_vars)
if command and not command.strip().lower().startswith("chcp"):
command = f"chcp 65001 >nul 2>&1 && {command}"
# Execute command with inherited environment variables
result = subprocess.run(
command,
shell=True,
cwd=self.cwd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding="utf-8",
errors="replace",
timeout=timeout,
env=env
)
logger.debug(f"[Bash] Exit code: {result.returncode}")
logger.debug(f"[Bash] Stdout length: {len(result.stdout)}")
logger.debug(f"[Bash] Stderr length: {len(result.stderr)}")
# Workaround for exit code 126 with no output
if result.returncode == 126 and not result.stdout and not result.stderr:
logger.warning(f"[Bash] Exit 126 with no output - trying alternative execution method")
# Try using argument list instead of shell=True
import shlex
try:
parts = shlex.split(command)
if len(parts) > 0:
logger.info(f"[Bash] Retrying with argument list: {parts[:3]}...")
retry_result = subprocess.run(
parts,
cwd=self.cwd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding="utf-8",
errors="replace",
timeout=timeout,
env=env
)
logger.debug(f"[Bash] Retry exit code: {retry_result.returncode}, stdout: {len(retry_result.stdout)}, stderr: {len(retry_result.stderr)}")
# If retry succeeded, use retry result
if retry_result.returncode == 0 or retry_result.stdout or retry_result.stderr:
result = retry_result
else:
# Both attempts failed - check if this is openai-image-vision skill
if 'openai-image-vision' in command or 'vision.sh' in command:
# Create a mock result with helpful error message
from types import SimpleNamespace
result = SimpleNamespace(
returncode=1,
stdout='{"error": "图片无法解析", "reason": "该图片格式可能不受支持,或图片文件存在问题", "suggestion": "请尝试其他图片"}',
stderr=''
)
logger.info(f"[Bash] Converted exit 126 to user-friendly image error message for vision skill")
except Exception as retry_err:
logger.warning(f"[Bash] Retry failed: {retry_err}")
# Combine stdout and stderr
output = result.stdout
if result.stderr:
output += "\n" + result.stderr
# Check if we need to save full output to temp file
temp_file_path = None
total_bytes = len(output.encode('utf-8'))
if total_bytes > DEFAULT_MAX_BYTES:
# Save full output to temp file
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.log', prefix='bash-') as f:
f.write(output)
temp_file_path = f.name
# Apply tail truncation
truncation = truncate_tail(output)
output_text = truncation.content or "(no output)"
# Build result
details = {}
if truncation.truncated:
details["truncation"] = truncation.to_dict()
if temp_file_path:
details["full_output_path"] = temp_file_path
# Build notice
start_line = truncation.total_lines - truncation.output_lines + 1
end_line = truncation.total_lines
if truncation.last_line_partial:
# Edge case: last line alone > 30KB
last_line = output.split('\n')[-1] if output else ""
last_line_size = format_size(len(last_line.encode('utf-8')))
output_text += f"\n\n[Showing last {format_size(truncation.output_bytes)} of line {end_line} (line is {last_line_size}). Full output: {temp_file_path}]"
elif truncation.truncated_by == "lines":
output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines}. Full output: {temp_file_path}]"
else:
output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Full output: {temp_file_path}]"
# Check exit code
if result.returncode != 0:
output_text += f"\n\nCommand exited with code {result.returncode}"
return ToolResult.fail({
"output": output_text,
"exit_code": result.returncode,
"details": details if details else None
})
return ToolResult.success({
"output": output_text,
"exit_code": result.returncode,
"details": details if details else None
})
except subprocess.TimeoutExpired:
return ToolResult.fail(f"Error: Command timed out after {timeout} seconds")
except Exception as e:
return ToolResult.fail(f"Error executing command: {str(e)}")
def _get_safety_warning(self, command: str) -> str:
"""
Get safety warning for potentially dangerous commands
Only warns about extremely dangerous system-level operations
:param command: Command to check
:return: Warning message if dangerous, empty string if safe
"""
cmd_lower = command.lower().strip()
# Only block extremely dangerous system operations
dangerous_patterns = [
# System shutdown/reboot
("shutdown", "This command will shut down the system"),
("reboot", "This command will reboot the system"),
("halt", "This command will halt the system"),
("poweroff", "This command will power off the system"),
# Critical system modifications
("rm -rf /", "This command will delete the entire filesystem"),
("rm -rf /*", "This command will delete the entire filesystem"),
("dd if=/dev/zero", "This command can destroy disk data"),
("mkfs", "This command will format a filesystem, destroying all data"),
("fdisk", "This command modifies disk partitions"),
# User/system management (only if targeting system users)
("userdel root", "This command will delete the root user"),
("passwd root", "This command will change the root password"),
]
for pattern, warning in dangerous_patterns:
if pattern in cmd_lower:
return warning
# Check for recursive deletion outside workspace
if "rm" in cmd_lower and "-rf" in cmd_lower:
# Allow deletion within current workspace
if not any(path in cmd_lower for path in ["./", self.cwd.lower()]):
# Check if targeting system directories
system_dirs = ["/bin", "/usr", "/etc", "/var", "/home", "/root", "/sys", "/proc"]
if any(sysdir in cmd_lower for sysdir in system_dirs):
return "This command will recursively delete system directories"
return "" # No warning needed
@staticmethod
def _convert_env_vars_for_windows(command: str, dotenv_vars: dict) -> str:
"""
Convert bash-style $VAR / ${VAR} references to cmd.exe %VAR% syntax.
Only converts variables loaded from .env (user-configured API keys etc.)
to avoid breaking $PATH, jq expressions, regex, etc.
"""
if not dotenv_vars:
return command
def replace_match(m):
var_name = m.group(1) or m.group(2)
if var_name in dotenv_vars:
return f"%{var_name}%"
return m.group(0)
return re.sub(r'\$\{(\w+)\}|\$(\w+)', replace_match, command)

View File

@@ -0,0 +1,3 @@
from agent.tools.browser.browser_tool import BrowserTool
__all__ = ["BrowserTool"]

View File

@@ -0,0 +1,708 @@
"""
Browser service - Playwright wrapper managing browser lifecycle and page operations.
All Playwright calls run on a dedicated background thread so that callers from
any worker thread can safely use the service. An idle-timeout mechanism
automatically shuts down the browser (and its thread) after a configurable
period of inactivity to free resources.
"""
import os
import sys
import uuid
import queue
import threading
from typing import Optional, Dict, Any, List, Callable
from common.log import logger
try:
from playwright.sync_api import sync_playwright, Browser, BrowserContext, Page, Playwright
_HAS_PLAYWRIGHT = True
except ImportError:
_HAS_PLAYWRIGHT = False
# ---------------------------------------------------------------------------
# Snapshot DOM helpers
# ---------------------------------------------------------------------------
# Tags that typically carry useful content for an agent
_INTERACTIVE_TAGS = {
"a", "button", "input", "textarea", "select", "option",
"label", "details", "summary",
}
_SEMANTIC_TAGS = {
"h1", "h2", "h3", "h4", "h5", "h6",
"p", "li", "td", "th", "caption", "figcaption", "blockquote", "pre", "code",
"nav", "main", "article", "section", "header", "footer", "form", "table",
"img", "video", "audio",
}
_KEEP_TAGS = _INTERACTIVE_TAGS | _SEMANTIC_TAGS
_SNAPSHOT_JS = """
() => {
const KEEP = new Set(%s);
const INTERACTIVE = new Set(%s);
const SKIP = new Set(["script","style","noscript","svg","path","meta","link","br","hr"]);
let refCounter = 0;
const refMap = {};
function visible(el) {
if (!(el instanceof HTMLElement)) return true;
const st = window.getComputedStyle(el);
if (st.display === "none" || st.visibility === "hidden") return false;
if (parseFloat(st.opacity) === 0) return false;
return true;
}
function walk(node) {
if (node.nodeType === Node.TEXT_NODE) {
const t = node.textContent.trim();
return t ? t : null;
}
if (node.nodeType !== Node.ELEMENT_NODE) return null;
const tag = node.tagName.toLowerCase();
if (SKIP.has(tag)) return null;
if (!visible(node)) return null;
const children = [];
for (const ch of node.childNodes) {
const r = walk(ch);
if (r !== null) {
if (typeof r === "string") children.push(r);
else children.push(r);
}
}
const keep = KEEP.has(tag);
if (!keep) {
// Unwrap: promote children
if (children.length === 0) return null;
if (children.length === 1) return children[0];
return children;
}
const obj = { tag };
if (INTERACTIVE.has(tag)) {
refCounter++;
obj.ref = refCounter;
refMap[refCounter] = node;
}
// Attributes
if (tag === "a" && node.href) obj.href = node.getAttribute("href");
if (tag === "img") {
obj.alt = node.alt || "";
obj.src = node.getAttribute("src") || "";
}
if (tag === "input" || tag === "textarea" || tag === "select") {
obj.type = node.type || "text";
obj.name = node.name || undefined;
obj.value = node.value || undefined;
obj.placeholder = node.placeholder || undefined;
if (node.disabled) obj.disabled = true;
if (tag === "input" && node.type === "checkbox") obj.checked = node.checked;
}
if (tag === "button") {
if (node.disabled) obj.disabled = true;
}
if (tag === "option") {
obj.value = node.value;
if (node.selected) obj.selected = true;
}
if (tag === "label" && node.htmlFor) obj.for = node.htmlFor;
// Role / aria-label
const role = node.getAttribute("role");
if (role) obj.role = role;
const ariaLabel = node.getAttribute("aria-label");
if (ariaLabel) obj.ariaLabel = ariaLabel;
// Children
if (children.length === 1 && typeof children[0] === "string") {
obj.text = children[0];
} else if (children.length > 0) {
obj.children = children;
}
return obj;
}
// Store refMap on window for later use by click/fill actions
const result = walk(document.body);
window.__cowRefMap = refMap;
return { tree: result, refCount: refCounter };
}
""" % (
str(list(_KEEP_TAGS)),
str(list(_INTERACTIVE_TAGS)),
)
def _should_use_headless() -> bool:
"""Decide headless mode: headless on Linux servers without display, headed elsewhere."""
if sys.platform in ("win32", "darwin"):
return False
# Linux: check for display
if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"):
return False
return True
def _flatten_tree(node, indent=0) -> List[str]:
"""Convert snapshot tree to compact text lines for LLM consumption."""
if node is None:
return []
if isinstance(node, str):
return [" " * indent + node]
if isinstance(node, list):
lines = []
for child in node:
lines.extend(_flatten_tree(child, indent))
return lines
if not isinstance(node, dict):
return []
tag = node.get("tag", "?")
ref = node.get("ref")
parts = [tag]
if ref:
parts[0] = f"[{ref}] {tag}"
# Inline attributes
for attr in ("type", "name", "href", "alt", "role", "ariaLabel", "placeholder", "value"):
val = node.get(attr)
if val:
# Truncate long values
s = str(val)
if len(s) > 80:
s = s[:77] + "..."
parts.append(f'{attr}="{s}"')
for flag in ("disabled", "checked", "selected"):
if node.get(flag):
parts.append(flag)
prefix = " " * indent
header = prefix + " ".join(parts)
text = node.get("text")
if text:
# Truncate long text
if len(text) > 120:
text = text[:117] + "..."
header += f": {text}"
lines = [header]
children = node.get("children", [])
for child in children:
lines.extend(_flatten_tree(child, indent + 2))
return lines
class BrowserService:
"""Manages a Playwright browser on a dedicated background thread.
All Playwright operations are dispatched to a single long-lived thread via
a task queue. Callers from *any* worker thread can use the public API
safely. An idle timer automatically shuts the browser down after
``idle_timeout`` seconds of inactivity (default 300 = 5 min).
"""
_IDLE_TIMEOUT_DEFAULT = 300 # seconds
def __init__(self, config: Optional[Dict[str, Any]] = None):
self._config = config or {}
self._headless: Optional[bool] = None
self._screenshot_dir: Optional[str] = None
# Background thread state
self._thread: Optional[threading.Thread] = None
self._task_queue: queue.Queue = queue.Queue()
self._lock = threading.Lock()
self._alive = False
self._ready = threading.Event()
# Playwright objects (only accessed on the background thread)
self._playwright = None
self._browser = None
self._context = None
self._page = None
# Idle auto-release
idle_cfg = self._config.get("idle_timeout")
self._idle_timeout: float = float(idle_cfg) if idle_cfg is not None else self._IDLE_TIMEOUT_DEFAULT
self._idle_timer: Optional[threading.Timer] = None
# ------------------------------------------------------------------
# Background-thread lifecycle
# ------------------------------------------------------------------
def _start_thread(self):
"""Start the dedicated Playwright thread if not already running."""
with self._lock:
if self._alive and self._thread and self._thread.is_alive():
return
# Wait for old thread to fully exit before creating a new one
old = self._thread
if old and old.is_alive():
old.join(timeout=5)
# Fresh queue to avoid stale sentinels from a previous close()
self._task_queue = queue.Queue()
self._alive = True
self._ready = threading.Event()
self._thread = threading.Thread(target=self._run_loop, daemon=True, name="BrowserThread")
self._thread.start()
# Block until browser is ready (or failed)
self._ready.wait(timeout=30)
def _run_loop(self):
"""Event loop running on the dedicated thread. Processes tasks until stopped."""
logger.info("[Browser] Background thread started")
try:
self._launch_browser()
except Exception as e:
logger.error(f"[Browser] Failed to launch browser: {e}")
self._alive = False
self._ready.set()
self._drain_queue(RuntimeError(f"Browser launch failed: {e}"))
return
self._ready.set()
while self._alive:
try:
task = self._task_queue.get(timeout=1.0)
except queue.Empty:
continue
if task is None:
break
fn, args, kwargs, result_slot = task
try:
result_slot["value"] = fn(*args, **kwargs)
except Exception as e:
result_slot["error"] = e
finally:
result_slot["event"].set()
self._shutdown_browser()
self._drain_queue(RuntimeError("Browser thread stopped"))
logger.info("[Browser] Background thread exited")
def _drain_queue(self, error: Exception):
"""Unblock all callers waiting on the queue with an error."""
while True:
try:
task = self._task_queue.get_nowait()
except queue.Empty:
break
if task is None:
continue
_, _, _, result_slot = task
result_slot["error"] = error
result_slot["event"].set()
def _launch_browser(self):
"""Launch Chromium on the background thread."""
if self._headless is None:
headless_cfg = self._config.get("headless")
self._headless = headless_cfg if headless_cfg is not None else _should_use_headless()
launch_args = ["--disable-dev-shm-usage"]
if self._headless:
launch_args.append("--no-sandbox")
extra_args = self._config.get("launch_args", [])
if extra_args:
launch_args.extend(extra_args)
viewport_w = self._config.get("viewport_width", 1280)
viewport_h = self._config.get("viewport_height", 720)
self._playwright = sync_playwright().start()
logger.info(f"[Browser] Launching Chromium (headless={self._headless})")
self._browser = self._playwright.chromium.launch(
headless=self._headless,
args=launch_args,
)
self._context = self._browser.new_context(
viewport={"width": viewport_w, "height": viewport_h},
user_agent=(
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/131.0.0.0 Safari/537.36"
),
)
self._page = self._context.new_page()
logger.info("[Browser] Browser ready")
def _shutdown_browser(self):
"""Shut down all Playwright resources on the background thread."""
self._cancel_idle_timer()
for obj, label in [
(self._context, "context"),
(self._browser, "browser"),
]:
try:
if obj:
obj.close()
except Exception as e:
logger.debug(f"[Browser] {label} close error: {e}")
try:
if self._playwright:
self._playwright.stop()
except Exception as e:
logger.debug(f"[Browser] playwright stop error: {e}")
self._page = None
self._context = None
self._browser = None
self._playwright = None
logger.info("[Browser] Browser closed")
def _submit(self, fn: Callable, *args, **kwargs):
"""Submit *fn* to the background thread and block until it completes."""
self._start_thread()
if not self._alive:
raise RuntimeError("Browser is not available")
self._reset_idle_timer()
result_slot: Dict[str, Any] = {"event": threading.Event()}
self._task_queue.put((fn, args, kwargs, result_slot))
# Timeout prevents permanent hang if the background thread crashes
completed = result_slot["event"].wait(timeout=120)
if not completed:
raise TimeoutError("Browser operation timed out (120s)")
if "error" in result_slot:
raise result_slot["error"]
return result_slot.get("value")
# ------------------------------------------------------------------
# Idle auto-release
# ------------------------------------------------------------------
def _reset_idle_timer(self):
self._cancel_idle_timer()
if self._idle_timeout > 0:
self._idle_timer = threading.Timer(self._idle_timeout, self._on_idle_timeout)
self._idle_timer.daemon = True
self._idle_timer.start()
def _cancel_idle_timer(self):
if self._idle_timer:
self._idle_timer.cancel()
self._idle_timer = None
def _on_idle_timeout(self):
logger.info(f"[Browser] Idle for {self._idle_timeout}s, auto-releasing browser")
self.close()
# ------------------------------------------------------------------
# Public lifecycle
# ------------------------------------------------------------------
def close(self):
"""Shut down browser and background thread (safe from any thread)."""
self._cancel_idle_timer()
with self._lock:
if not self._alive:
return
self._alive = False
t = self._thread
if self._task_queue is not None:
self._task_queue.put(None)
if t is not None and t.is_alive():
t.join(timeout=10)
with self._lock:
self._thread = None
# ------------------------------------------------------------------
# Actions (each method is dispatched to the background thread)
# ------------------------------------------------------------------
def navigate(self, url: str, timeout: int = 30000) -> Dict[str, Any]:
return self._submit(self._do_navigate, url, timeout)
def _do_navigate(self, url: str, timeout: int) -> Dict[str, Any]:
page = self._page
try:
resp = page.goto(url, wait_until="domcontentloaded", timeout=timeout)
status = resp.status if resp else None
except Exception as e:
return {"error": f"Navigation failed: {e}"}
try:
page.wait_for_load_state("networkidle", timeout=8000)
except Exception:
pass
page.wait_for_timeout(500)
try:
title = page.title()
except Exception:
title = ""
try:
current_url = page.url
except Exception:
current_url = url
return {"url": current_url, "title": title, "status": status}
def snapshot(self, selector: Optional[str] = None) -> str:
return self._submit(self._do_snapshot, selector)
def _do_snapshot(self, selector: Optional[str] = None) -> str:
page = self._page
try:
result = page.evaluate(_SNAPSHOT_JS)
except Exception as e:
return f"[Snapshot error: {e}]"
tree = result.get("tree")
ref_count = result.get("refCount", 0)
lines = _flatten_tree(tree)
try:
title = page.title()
except Exception:
title = ""
try:
url = page.url
except Exception:
url = ""
header = f"Page: {title} ({url})\nInteractive elements: {ref_count}\n---"
body = "\n".join(lines)
max_chars = self._config.get("snapshot_max_chars", 30000)
if len(body) > max_chars:
body = body[:max_chars] + "\n... [snapshot truncated]"
return f"{header}\n{body}"
def screenshot(self, full_page: bool = False, cwd: str = "") -> str:
return self._submit(self._do_screenshot, full_page, cwd)
def _do_screenshot(self, full_page: bool = False, cwd: str = "") -> str:
page = self._page
save_dir = self._get_screenshot_dir(cwd)
filename = f"screenshot_{uuid.uuid4().hex[:8]}.png"
filepath = os.path.join(save_dir, filename)
page.screenshot(path=filepath, full_page=full_page)
logger.info(f"[Browser] Screenshot saved: {filepath}")
return filepath
def click(self, ref: Optional[int] = None, selector: Optional[str] = None,
timeout: int = 5000) -> Dict[str, Any]:
return self._submit(self._do_click, ref, selector, timeout)
def _do_click(self, ref, selector, timeout) -> Dict[str, Any]:
page = self._page
try:
if ref is not None:
result = page.evaluate(f"""
() => {{
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
if (!el) return {{ error: "ref {ref} not found. Run snapshot first." }};
el.click();
return {{ clicked: true, tag: el.tagName.toLowerCase() }};
}}
""")
if result.get("error"):
return result
page.wait_for_timeout(500)
return result
elif selector:
page.click(selector, timeout=timeout)
return {"clicked": True, "selector": selector}
else:
return {"error": "Provide either ref (from snapshot) or selector"}
except Exception as e:
return {"error": f"Click failed: {e}"}
def fill(self, text: str, ref: Optional[int] = None,
selector: Optional[str] = None, timeout: int = 5000) -> Dict[str, Any]:
return self._submit(self._do_fill, text, ref, selector, timeout)
def _do_fill(self, text, ref, selector, timeout) -> Dict[str, Any]:
page = self._page
try:
if ref is not None:
result = page.evaluate(f"""
() => {{
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
if (!el) return {{ error: "ref {ref} not found. Run snapshot first." }};
el.focus();
el.value = "";
return {{ tag: el.tagName.toLowerCase(), name: el.name || "" }};
}}
""")
if result.get("error"):
return result
page.keyboard.type(text)
return {"filled": True, "ref": ref, "text": text}
elif selector:
page.fill(selector, text, timeout=timeout)
return {"filled": True, "selector": selector, "text": text}
else:
return {"error": "Provide either ref (from snapshot) or selector"}
except Exception as e:
return {"error": f"Fill failed: {e}"}
def select(self, value: str, ref: Optional[int] = None,
selector: Optional[str] = None, timeout: int = 5000) -> Dict[str, Any]:
return self._submit(self._do_select, value, ref, selector, timeout)
def _do_select(self, value, ref, selector, timeout) -> Dict[str, Any]:
page = self._page
try:
if ref is not None:
result = page.evaluate(f"""
() => {{
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
if (!el || el.tagName.toLowerCase() !== "select")
return {{ error: "ref {ref} is not a <select> element" }};
el.value = {repr(value)};
el.dispatchEvent(new Event("change", {{ bubbles: true }}));
return {{ selected: true, value: el.value }};
}}
""")
return result
elif selector:
page.select_option(selector, value, timeout=timeout)
return {"selected": True, "selector": selector, "value": value}
else:
return {"error": "Provide either ref (from snapshot) or selector"}
except Exception as e:
return {"error": f"Select failed: {e}"}
def scroll(self, direction: str = "down", amount: int = 500) -> Dict[str, Any]:
return self._submit(self._do_scroll, direction, amount)
def _do_scroll(self, direction, amount) -> Dict[str, Any]:
page = self._page
delta_map = {
"down": (0, amount),
"up": (0, -amount),
"right": (amount, 0),
"left": (-amount, 0),
}
dx, dy = delta_map.get(direction, (0, amount))
try:
page.mouse.wheel(dx, dy)
page.wait_for_timeout(300)
scroll_info = page.evaluate("""
() => ({
scrollX: window.scrollX,
scrollY: window.scrollY,
scrollHeight: document.documentElement.scrollHeight,
clientHeight: document.documentElement.clientHeight
})
""")
return {"scrolled": direction, "amount": amount, **scroll_info}
except Exception as e:
return {"error": f"Scroll failed: {e}"}
def wait(self, selector: Optional[str] = None, timeout: int = 5000,
state: str = "visible") -> Dict[str, Any]:
return self._submit(self._do_wait, selector, timeout, state)
def _do_wait(self, selector, timeout, state) -> Dict[str, Any]:
page = self._page
try:
if selector:
page.wait_for_selector(selector, timeout=timeout, state=state)
return {"waited": True, "selector": selector, "state": state}
else:
page.wait_for_timeout(timeout)
return {"waited": True, "timeout_ms": timeout}
except Exception as e:
return {"error": f"Wait failed: {e}"}
def go_back(self) -> Dict[str, Any]:
return self._submit(self._do_go_back)
def _do_go_back(self) -> Dict[str, Any]:
page = self._page
try:
page.go_back(wait_until="domcontentloaded", timeout=10000)
try:
title = page.title()
except Exception:
title = ""
try:
url = page.url
except Exception:
url = ""
return {"url": url, "title": title}
except Exception as e:
return {"error": f"Go back failed: {e}"}
def go_forward(self) -> Dict[str, Any]:
return self._submit(self._do_go_forward)
def _do_go_forward(self) -> Dict[str, Any]:
page = self._page
try:
page.go_forward(wait_until="domcontentloaded", timeout=10000)
try:
title = page.title()
except Exception:
title = ""
try:
url = page.url
except Exception:
url = ""
return {"url": url, "title": title}
except Exception as e:
return {"error": f"Go forward failed: {e}"}
def get_text(self, selector: str) -> Dict[str, Any]:
return self._submit(self._do_get_text, selector)
def _do_get_text(self, selector) -> Dict[str, Any]:
page = self._page
try:
text = page.text_content(selector, timeout=5000)
return {"text": text or ""}
except Exception as e:
return {"error": f"Get text failed: {e}"}
def evaluate(self, script: str) -> Dict[str, Any]:
return self._submit(self._do_evaluate, script)
def _do_evaluate(self, script) -> Dict[str, Any]:
page = self._page
try:
result = page.evaluate(script)
return {"result": result}
except Exception as e:
return {"error": f"Evaluate failed: {e}"}
def press(self, key: str) -> Dict[str, Any]:
return self._submit(self._do_press, key)
def _do_press(self, key) -> Dict[str, Any]:
page = self._page
try:
page.keyboard.press(key)
page.wait_for_timeout(300)
return {"pressed": key}
except Exception as e:
return {"error": f"Press failed: {e}"}
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _get_screenshot_dir(self, cwd: str = "") -> str:
if self._screenshot_dir and os.path.isdir(self._screenshot_dir):
return self._screenshot_dir
base = cwd or os.getcwd()
d = os.path.join(base, "tmp")
os.makedirs(d, exist_ok=True)
self._screenshot_dir = d
return d

View File

@@ -0,0 +1,290 @@
"""
Browser tool - Control a Chromium browser for web navigation and interaction.
Uses Playwright under the hood. Browser instance is lazily started on first
use, reused across tool calls within the same session, and cleaned up via
close().
"""
import json
import os
from typing import Dict, Any, Optional
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.browser.browser_service import BrowserService
from common.log import logger
class BrowserTool(BaseTool):
"""Single tool exposing all browser actions via an 'action' parameter."""
name: str = "browser"
description: str = (
"Control a browser to navigate web pages, interact with elements, and extract content. "
"Actions: navigate, snapshot, click, fill, select, scroll, screenshot, wait, back, forward, "
"get_text, press, evaluate.\n\n"
"Workflow: navigate (auto-includes snapshot with element refs) → click/fill/select by ref → snapshot to verify.\n\n"
"Use snapshot as the primary way to read pages. Use screenshot + send to show key results to the user. "
"For login/CAPTCHA/authorization etc., screenshot and ask the user for help."
)
params: dict = {
"type": "object",
"properties": {
"action": {
"type": "string",
"description": (
"The browser action to perform. One of: "
"navigate, snapshot, click, fill, select, scroll, "
"screenshot, wait, back, forward, get_text, press, evaluate"
),
"enum": [
"navigate", "snapshot", "click", "fill", "select", "scroll",
"screenshot", "wait", "back", "forward", "get_text", "press",
"evaluate"
]
},
"url": {
"type": "string",
"description": "URL to navigate to (for 'navigate' action)"
},
"ref": {
"type": "integer",
"description": "Element ref number from snapshot (for click/fill/select)"
},
"selector": {
"type": "string",
"description": "CSS selector as fallback when ref is unavailable (for click/fill/select/wait/get_text)"
},
"text": {
"type": "string",
"description": "Text to type (for 'fill' action)"
},
"value": {
"type": "string",
"description": "Option value (for 'select' action)"
},
"key": {
"type": "string",
"description": "Key to press, e.g. Enter, Tab, Escape (for 'press' action)"
},
"direction": {
"type": "string",
"description": "Scroll direction: up, down, left, right (for 'scroll' action, default: down)"
},
"script": {
"type": "string",
"description": "JavaScript code to execute (for 'evaluate' action)"
},
"full_page": {
"type": "boolean",
"description": "Capture full page screenshot (for 'screenshot' action, default: false)"
},
"timeout": {
"type": "integer",
"description": "Timeout in milliseconds (optional, default varies by action)"
}
},
"required": ["action"]
}
_shared_service: Optional[BrowserService] = None
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
self._service: Optional[BrowserService] = None
def _get_service(self) -> BrowserService:
"""Get or create the browser service, sharing across copies."""
if self._service is not None:
return self._service
# Reuse shared service across tool copies within the same session
if BrowserTool._shared_service is not None:
self._service = BrowserTool._shared_service
return self._service
self._service = BrowserService(self.config)
BrowserTool._shared_service = self._service
return self._service
def execute(self, args: Dict[str, Any]) -> ToolResult:
action = args.get("action", "").strip().lower()
if not action:
return ToolResult.fail("Error: 'action' parameter is required")
handler = self._ACTION_MAP.get(action)
if not handler:
valid = ", ".join(sorted(self._ACTION_MAP.keys()))
return ToolResult.fail(f"Unknown action '{action}'. Valid actions: {valid}")
try:
return handler(self, args)
except Exception as e:
logger.error(f"[Browser] Action '{action}' error: {e}")
return ToolResult.fail(f"Browser error ({action}): {e}")
# ------------------------------------------------------------------
# Action handlers
# ------------------------------------------------------------------
def _do_navigate(self, args: Dict[str, Any]) -> ToolResult:
url = args.get("url", "").strip()
if not url:
return ToolResult.fail("Error: 'url' is required for navigate action")
if not url.startswith(("http://", "https://")):
url = "https://" + url
timeout = args.get("timeout", 30000)
service = self._get_service()
result = service.navigate(url, timeout=timeout)
if "error" in result:
return ToolResult.fail(result["error"])
# Auto-snapshot after navigation so the agent gets page content in one call
snapshot_text = service.snapshot()
return ToolResult.success(
f"Navigated to: {result['url']}\nTitle: {result['title']}\nStatus: {result['status']}\n\n"
f"--- Page Snapshot ---\n{snapshot_text}"
)
def _do_snapshot(self, args: Dict[str, Any]) -> ToolResult:
selector = args.get("selector")
text = self._get_service().snapshot(selector=selector)
return ToolResult.success(text)
def _do_click(self, args: Dict[str, Any]) -> ToolResult:
ref = args.get("ref")
selector = args.get("selector")
timeout = args.get("timeout", 5000)
result = self._get_service().click(ref=ref, selector=selector, timeout=timeout)
if "error" in result:
return ToolResult.fail(result["error"])
return ToolResult.success(f"Clicked successfully. Use 'snapshot' to see updated page.")
def _do_fill(self, args: Dict[str, Any]) -> ToolResult:
text = args.get("text", "")
ref = args.get("ref")
selector = args.get("selector")
timeout = args.get("timeout", 5000)
if not text and text != "":
return ToolResult.fail("Error: 'text' is required for fill action")
result = self._get_service().fill(text, ref=ref, selector=selector, timeout=timeout)
if "error" in result:
return ToolResult.fail(result["error"])
return ToolResult.success(f"Filled text into element. Use 'snapshot' to verify.")
def _do_select(self, args: Dict[str, Any]) -> ToolResult:
value = args.get("value", "")
ref = args.get("ref")
selector = args.get("selector")
timeout = args.get("timeout", 5000)
if not value:
return ToolResult.fail("Error: 'value' is required for select action")
result = self._get_service().select(value, ref=ref, selector=selector, timeout=timeout)
if "error" in result:
return ToolResult.fail(result["error"])
return ToolResult.success(f"Selected option '{value}'.")
def _do_scroll(self, args: Dict[str, Any]) -> ToolResult:
direction = args.get("direction", "down")
amount = args.get("timeout", 500) # reuse timeout field or default
if "amount" in args:
amount = args["amount"]
result = self._get_service().scroll(direction=direction, amount=amount)
if "error" in result:
return ToolResult.fail(result["error"])
pos = f"scrollY={result.get('scrollY', '?')}/{result.get('scrollHeight', '?')}"
return ToolResult.success(f"Scrolled {direction}. Position: {pos}")
def _do_screenshot(self, args: Dict[str, Any]) -> ToolResult:
full_page = args.get("full_page", False)
filepath = self._get_service().screenshot(full_page=full_page, cwd=self.cwd)
return ToolResult.success(f"Screenshot saved to: {filepath}")
def _do_wait(self, args: Dict[str, Any]) -> ToolResult:
selector = args.get("selector")
timeout = args.get("timeout", 5000)
result = self._get_service().wait(selector=selector, timeout=timeout)
if "error" in result:
return ToolResult.fail(result["error"])
return ToolResult.success(f"Wait completed.")
def _do_back(self, args: Dict[str, Any]) -> ToolResult:
result = self._get_service().go_back()
if "error" in result:
return ToolResult.fail(result["error"])
return ToolResult.success(f"Navigated back to: {result['url']}")
def _do_forward(self, args: Dict[str, Any]) -> ToolResult:
result = self._get_service().go_forward()
if "error" in result:
return ToolResult.fail(result["error"])
return ToolResult.success(f"Navigated forward to: {result['url']}")
def _do_get_text(self, args: Dict[str, Any]) -> ToolResult:
selector = args.get("selector", "").strip()
if not selector:
return ToolResult.fail("Error: 'selector' is required for get_text action")
result = self._get_service().get_text(selector)
if "error" in result:
return ToolResult.fail(result["error"])
return ToolResult.success(result["text"])
def _do_press(self, args: Dict[str, Any]) -> ToolResult:
key = args.get("key", "").strip()
if not key:
return ToolResult.fail("Error: 'key' is required for press action")
result = self._get_service().press(key)
if "error" in result:
return ToolResult.fail(result["error"])
return ToolResult.success(f"Pressed key: {key}")
def _do_evaluate(self, args: Dict[str, Any]) -> ToolResult:
script = args.get("script", "").strip()
if not script:
return ToolResult.fail("Error: 'script' is required for evaluate action")
result = self._get_service().evaluate(script)
if "error" in result:
return ToolResult.fail(result["error"])
val = result.get("result")
if isinstance(val, (dict, list)):
return ToolResult.success(json.dumps(val, ensure_ascii=False, indent=2))
return ToolResult.success(str(val) if val is not None else "(no return value)")
# Action dispatch table
_ACTION_MAP = {
"navigate": _do_navigate,
"snapshot": _do_snapshot,
"click": _do_click,
"fill": _do_fill,
"select": _do_select,
"scroll": _do_scroll,
"screenshot": _do_screenshot,
"wait": _do_wait,
"back": _do_back,
"forward": _do_forward,
"get_text": _do_get_text,
"press": _do_press,
"evaluate": _do_evaluate,
}
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def copy(self):
"""Share browser instance across tool copies (avoids re-launching)."""
new_tool = BrowserTool(self.config)
new_tool.model = self.model
new_tool.context = getattr(self, "context", None)
new_tool.cwd = self.cwd
new_tool._service = self._service
return new_tool
def close(self):
"""Release browser resources."""
if self._service:
self._service.close()
self._service = None
BrowserTool._shared_service = None
logger.info("[Browser] BrowserTool closed")

View File

@@ -0,0 +1,3 @@
from .edit import Edit
__all__ = ['Edit']

185
agent/tools/edit/edit.py Normal file
View File

@@ -0,0 +1,185 @@
"""
Edit tool - Precise file editing
Edit files through exact text replacement
"""
import os
from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
from common.utils import expand_path
from agent.tools.utils.diff import (
strip_bom,
detect_line_ending,
normalize_to_lf,
restore_line_endings,
normalize_for_fuzzy_match,
fuzzy_find_text,
generate_diff_string
)
class Edit(BaseTool):
"""Tool for precise file editing"""
name: str = "edit"
description: str = "Edit a file by replacing exact text, or append to end if oldText is empty. For append: use empty oldText. For replace: oldText must match exactly (including whitespace)."
params: dict = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the file to edit (relative or absolute)"
},
"oldText": {
"type": "string",
"description": "Text to find and replace. Use empty string to append to end of file. For replacement: must match exactly including whitespace."
},
"newText": {
"type": "string",
"description": "New text to replace the old text with"
}
},
"required": ["path", "oldText", "newText"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
self.memory_manager = self.config.get("memory_manager", None)
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute file edit operation
:param args: Contains file path, old text and new text
:return: Operation result
"""
path = args.get("path", "").strip()
old_text = args.get("oldText", "")
new_text = args.get("newText", "")
if not path:
return ToolResult.fail("Error: path parameter is required")
# Resolve path
absolute_path = self._resolve_path(path)
# Check if file exists
if not os.path.exists(absolute_path):
return ToolResult.fail(f"Error: File not found: {path}")
# Check if readable/writable
if not os.access(absolute_path, os.R_OK | os.W_OK):
return ToolResult.fail(f"Error: File is not readable/writable: {path}")
try:
# Read file
with open(absolute_path, 'r', encoding='utf-8') as f:
raw_content = f.read()
# Remove BOM (LLM won't include invisible BOM in oldText)
bom, content = strip_bom(raw_content)
# Detect original line ending
original_ending = detect_line_ending(content)
# Normalize to LF
normalized_content = normalize_to_lf(content)
normalized_old_text = normalize_to_lf(old_text)
normalized_new_text = normalize_to_lf(new_text)
# Special case: empty oldText means append to end of file
if not old_text or not old_text.strip():
# Append mode: add newText to the end
# Add newline before newText if file doesn't end with one
if normalized_content and not normalized_content.endswith('\n'):
new_content = normalized_content + '\n' + normalized_new_text
else:
new_content = normalized_content + normalized_new_text
base_content = normalized_content # For verification
else:
# Normal edit mode: find and replace
# Use fuzzy matching to find old text (try exact match first, then fuzzy match)
match_result = fuzzy_find_text(normalized_content, normalized_old_text)
if not match_result.found:
return ToolResult.fail(
f"Error: Could not find the exact text in {path}. "
"The old text must match exactly including all whitespace and newlines."
)
# Calculate occurrence count (use fuzzy normalized content for consistency)
fuzzy_content = normalize_for_fuzzy_match(normalized_content)
fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
occurrences = fuzzy_content.count(fuzzy_old_text)
if occurrences > 1:
return ToolResult.fail(
f"Error: Found {occurrences} occurrences of the text in {path}. "
"The text must be unique. Please provide more context to make it unique."
)
# Execute replacement (use matched text position)
base_content = match_result.content_for_replacement
new_content = (
base_content[:match_result.index] +
normalized_new_text +
base_content[match_result.index + match_result.match_length:]
)
# Verify replacement actually changed content
if base_content == new_content:
return ToolResult.fail(
f"Error: No changes made to {path}. "
"The replacement produced identical content. "
"This might indicate an issue with special characters or the text not existing as expected."
)
# Restore original line endings
final_content = bom + restore_line_endings(new_content, original_ending)
# Write file
with open(absolute_path, 'w', encoding='utf-8') as f:
f.write(final_content)
# Generate diff
diff_result = generate_diff_string(base_content, new_content)
result = {
"message": f"Successfully replaced text in {path}",
"path": path,
"diff": diff_result['diff'],
"first_changed_line": diff_result['first_changed_line']
}
# Notify memory manager if file is in memory directory
if self.memory_manager and "memory/" in path:
try:
self.memory_manager.mark_dirty()
except Exception as e:
# Don't fail the edit if memory notification fails
pass
return ToolResult.success(result)
except UnicodeDecodeError:
return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {path}")
except PermissionError:
return ToolResult.fail(f"Error: Permission denied accessing {path}")
except Exception as e:
return ToolResult.fail(f"Error editing file: {str(e)}")
def _resolve_path(self, path: str) -> str:
"""
Resolve path to absolute path
:param path: Relative or absolute path
:return: Absolute path
"""
# Expand ~ to user home directory
path = expand_path(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))

View File

@@ -0,0 +1,3 @@
from agent.tools.env_config.env_config import EnvConfig
__all__ = ['EnvConfig']

View File

@@ -0,0 +1,286 @@
"""
Environment Configuration Tool - Manage API keys and environment variables
"""
import os
import re
from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger
from common.utils import expand_path
# API Key 知识库:常见的环境变量及其描述
API_KEY_REGISTRY = {
# AI 模型服务
"OPENAI_API_KEY": "OpenAI API 密钥 (用于GPT模型、Embedding模型)",
"GEMINI_API_KEY": "Google Gemini API 密钥",
"CLAUDE_API_KEY": "Claude API 密钥 (用于Claude模型)",
"LINKAI_API_KEY": "LinkAI智能体平台 API 密钥,支持多种模型切换",
# 搜索服务
"BOCHA_API_KEY": "博查 AI 搜索 API 密钥 ",
}
class EnvConfig(BaseTool):
"""Tool for managing environment variables (API keys, etc.)"""
name: str = "env_config"
description: str = (
"Manage API keys and skill configurations securely. "
"Use this tool when user wants to configure API keys (like BOCHA_API_KEY, OPENAI_API_KEY), "
"view configured keys, or manage skill settings. "
"Actions: 'set' (add/update key), 'get' (view specific key), 'list' (show all configured keys), 'delete' (remove key). "
"Values are automatically masked for security. Changes take effect immediately via hot reload."
)
params: dict = {
"type": "object",
"properties": {
"action": {
"type": "string",
"description": "Action to perform: 'set', 'get', 'list', 'delete'",
"enum": ["set", "get", "list", "delete"]
},
"key": {
"type": "string",
"description": (
"Environment variable key name. Common keys:\n"
"- OPENAI_API_KEY: OpenAI API (GPT models)\n"
"- OPENAI_API_BASE: OpenAI API base URL\n"
"- CLAUDE_API_KEY: Anthropic Claude API\n"
"- GEMINI_API_KEY: Google Gemini API\n"
"- LINKAI_API_KEY: LinkAI platform\n"
"- BOCHA_API_KEY: Bocha AI search (博查搜索)\n"
"Use exact key names (case-sensitive, all uppercase with underscores)"
)
},
"value": {
"type": "string",
"description": "Value to set for the environment variable (for 'set' action)"
}
},
"required": ["action"]
}
def __init__(self, config: dict = None):
self.config = config or {}
# Store env config in ~/.cow directory (outside workspace for security)
self.env_dir = expand_path("~/.cow")
self.env_path = os.path.join(self.env_dir, '.env')
self.agent_bridge = self.config.get("agent_bridge") # Reference to AgentBridge for hot reload
# Don't create .env file in __init__ to avoid issues during tool discovery
# It will be created on first use in execute()
def _ensure_env_file(self):
"""Ensure the .env file exists"""
# Create ~/.cow directory if it doesn't exist
os.makedirs(self.env_dir, exist_ok=True)
if not os.path.exists(self.env_path):
Path(self.env_path).touch()
logger.info(f"[EnvConfig] Created .env file at {self.env_path}")
def _mask_value(self, value: str) -> str:
"""Mask sensitive parts of a value for logging"""
if not value or len(value) <= 10:
return "***"
return f"{value[:6]}***{value[-4:]}"
def _read_env_file(self) -> Dict[str, str]:
"""Read all key-value pairs from .env file"""
env_vars = {}
if os.path.exists(self.env_path):
with open(self.env_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
# Skip empty lines and comments
if not line or line.startswith('#'):
continue
# Parse KEY=VALUE
match = re.match(r'^([^=]+)=(.*)$', line)
if match:
key, value = match.groups()
env_vars[key.strip()] = value.strip()
return env_vars
def _write_env_file(self, env_vars: Dict[str, str]):
"""Write all key-value pairs to .env file"""
with open(self.env_path, 'w', encoding='utf-8') as f:
f.write("# Environment variables for agent skills\n")
f.write("# Auto-managed by env_config tool\n\n")
for key, value in sorted(env_vars.items()):
f.write(f"{key}={value}\n")
def _reload_env(self):
"""Reload environment variables from .env file"""
env_vars = self._read_env_file()
for key, value in env_vars.items():
os.environ[key] = value
logger.debug(f"[EnvConfig] Reloaded {len(env_vars)} environment variables")
def _refresh_skills(self):
"""Refresh skills after environment variable changes"""
if self.agent_bridge:
try:
# Reload .env file
self._reload_env()
# Refresh skills in all agent instances
refreshed = self.agent_bridge.refresh_all_skills()
logger.info(f"[EnvConfig] Refreshed skills in {refreshed} agent instance(s)")
return True
except Exception as e:
logger.warning(f"[EnvConfig] Failed to refresh skills: {e}")
return False
return False
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute environment configuration operation
:param args: Contains action, key, and value parameters
:return: Result of the operation
"""
# Ensure .env file exists on first use
self._ensure_env_file()
action = args.get("action")
key = args.get("key")
value = args.get("value")
try:
if action == "set":
if not key or not value:
return ToolResult.fail("Error: 'key' and 'value' are required for 'set' action.")
# Read current env vars
env_vars = self._read_env_file()
# Update the key
env_vars[key] = value
# Write back to file
self._write_env_file(env_vars)
# Update current process env
os.environ[key] = value
logger.info(f"[EnvConfig] Set {key}={self._mask_value(value)}")
# Try to refresh skills immediately
refreshed = self._refresh_skills()
result = {
"message": f"Successfully set {key}",
"key": key,
"value": self._mask_value(value),
}
if refreshed:
result["note"] = "✅ Skills refreshed automatically - changes are now active"
else:
result["note"] = "⚠️ Skills not refreshed - restart agent to load new skills"
return ToolResult.success(result)
elif action == "get":
if not key:
return ToolResult.fail("Error: 'key' is required for 'get' action.")
# Check in file first, then in current env
env_vars = self._read_env_file()
value = env_vars.get(key) or os.getenv(key)
# Get description from registry
description = API_KEY_REGISTRY.get(key, "未知用途的环境变量")
if value is not None:
logger.info(f"[EnvConfig] Got {key}={self._mask_value(value)}")
return ToolResult.success({
"key": key,
"value": self._mask_value(value),
"description": description,
"exists": True,
"note": f"Value is masked for security. In bash, use ${key} directly — it is auto-injected."
})
else:
return ToolResult.success({
"key": key,
"description": description,
"exists": False,
"message": f"Environment variable '{key}' is not set"
})
elif action == "list":
env_vars = self._read_env_file()
# Build detailed variable list with descriptions
variables_with_info = {}
for key, value in env_vars.items():
variables_with_info[key] = {
"value": self._mask_value(value),
"description": API_KEY_REGISTRY.get(key, "未知用途的环境变量")
}
logger.info(f"[EnvConfig] Listed {len(env_vars)} environment variables")
if not env_vars:
return ToolResult.success({
"message": "No environment variables configured",
"variables": {},
"note": "常用的 API 密钥可以通过 env_config(action='set', key='KEY_NAME', value='your-key') 来配置"
})
return ToolResult.success({
"message": f"Found {len(env_vars)} environment variable(s)",
"variables": variables_with_info
})
elif action == "delete":
if not key:
return ToolResult.fail("Error: 'key' is required for 'delete' action.")
# Read current env vars
env_vars = self._read_env_file()
if key not in env_vars:
return ToolResult.success({
"message": f"Environment variable '{key}' was not set",
"key": key
})
# Remove the key
del env_vars[key]
# Write back to file
self._write_env_file(env_vars)
# Remove from current process env
if key in os.environ:
del os.environ[key]
logger.info(f"[EnvConfig] Deleted {key}")
# Try to refresh skills immediately
refreshed = self._refresh_skills()
result = {
"message": f"Successfully deleted {key}",
"key": key,
}
if refreshed:
result["note"] = "✅ Skills refreshed automatically - changes are now active"
else:
result["note"] = "⚠️ Skills not refreshed - restart agent to apply changes"
return ToolResult.success(result)
else:
return ToolResult.fail(f"Error: Unknown action '{action}'. Use 'set', 'get', 'list', or 'delete'.")
except Exception as e:
logger.error(f"[EnvConfig] Error: {e}", exc_info=True)
return ToolResult.fail(f"EnvConfig tool error: {str(e)}")

View File

@@ -0,0 +1,3 @@
from .ls import Ls
__all__ = ['Ls']

140
agent/tools/ls/ls.py Normal file
View File

@@ -0,0 +1,140 @@
"""
Ls tool - List directory contents
"""
import os
from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
from common.utils import expand_path
DEFAULT_LIMIT = 500
class Ls(BaseTool):
"""Tool for listing directory contents"""
name: str = "ls"
description: str = f"List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to {DEFAULT_LIMIT} entries or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."
params: dict = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Directory to list. IMPORTANT: Relative paths are based on workspace directory. To access directories outside workspace, use absolute paths starting with ~ or /."
},
"limit": {
"type": "integer",
"description": f"Maximum number of entries to return (default: {DEFAULT_LIMIT})"
}
},
"required": []
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute directory listing
:param args: Listing parameters
:return: Directory contents or error
"""
path = args.get("path", ".").strip()
limit = args.get("limit", DEFAULT_LIMIT)
# Resolve path
absolute_path = self._resolve_path(path)
# Security check: Prevent accessing sensitive config directory
env_config_dir = expand_path("~/.cow")
if os.path.abspath(absolute_path) == os.path.abspath(env_config_dir):
return ToolResult.fail(
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
)
if not os.path.exists(absolute_path):
# Provide helpful hint if using relative path
if not os.path.isabs(path) and not path.startswith('~'):
return ToolResult.fail(
f"Error: Path not found: {path}\n"
f"Resolved to: {absolute_path}\n"
f"Hint: Relative paths are based on workspace ({self.cwd}). For files outside workspace, use absolute paths."
)
return ToolResult.fail(f"Error: Path not found: {path}")
if not os.path.isdir(absolute_path):
return ToolResult.fail(f"Error: Not a directory: {path}")
try:
# Read directory entries
entries = os.listdir(absolute_path)
# Sort alphabetically (case-insensitive)
entries.sort(key=lambda x: x.lower())
# Format entries with directory indicators
results = []
entry_limit_reached = False
for entry in entries:
if len(results) >= limit:
entry_limit_reached = True
break
full_path = os.path.join(absolute_path, entry)
try:
if os.path.isdir(full_path):
results.append(entry + '/')
else:
results.append(entry)
except Exception:
# Skip entries we can't stat
continue
if not results:
return ToolResult.success({"message": "(empty directory)", "entries": []})
# Format output
raw_output = '\n'.join(results)
truncation = truncate_head(raw_output, max_lines=999999) # Only limit by bytes
output = truncation.content
details = {}
notices = []
if entry_limit_reached:
notices.append(f"{limit} entries limit reached. Use limit={limit * 2} for more")
details["entry_limit_reached"] = limit
if truncation.truncated:
notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
details["truncation"] = truncation.to_dict()
if notices:
output += f"\n\n[{'. '.join(notices)}]"
return ToolResult.success({
"output": output,
"entry_count": len(results),
"details": details if details else None
})
except PermissionError:
return ToolResult.fail(f"Error: Permission denied reading directory: {path}")
except Exception as e:
return ToolResult.fail(f"Error listing directory: {str(e)}")
def _resolve_path(self, path: str) -> str:
"""Resolve path to absolute path"""
# Expand ~ to user home directory
path = expand_path(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))

View File

@@ -0,0 +1,10 @@
"""
Memory tools for Agent
Provides memory_search and memory_get tools
"""
from agent.tools.memory.memory_search import MemorySearchTool
from agent.tools.memory.memory_get import MemoryGetTool
__all__ = ['MemorySearchTool', 'MemoryGetTool']

View File

@@ -0,0 +1,111 @@
"""
Memory get tool
Allows agents to read specific sections from memory files
"""
from agent.tools.base_tool import BaseTool
class MemoryGetTool(BaseTool):
"""Tool for reading memory file contents"""
name: str = "memory_get"
description: str = (
"Read specific content from memory files. "
"Use this to get full context from a memory file or specific line range."
)
params: dict = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Relative path to the memory file (e.g. 'MEMORY.md', 'memory/2026-01-01.md')"
},
"start_line": {
"type": "integer",
"description": "Starting line number (optional, default: 1)",
"default": 1
},
"num_lines": {
"type": "integer",
"description": "Number of lines to read (optional, reads all if not specified)"
}
},
"required": ["path"]
}
def __init__(self, memory_manager):
"""
Initialize memory get tool
Args:
memory_manager: MemoryManager instance
"""
super().__init__()
self.memory_manager = memory_manager
def execute(self, args: dict):
"""
Execute memory file read
Args:
args: Dictionary with path, start_line, num_lines
Returns:
ToolResult with file content
"""
from agent.tools.base_tool import ToolResult
path = args.get("path")
start_line = args.get("start_line", 1)
num_lines = args.get("num_lines")
if not path:
return ToolResult.fail("Error: path parameter is required")
try:
workspace_dir = self.memory_manager.config.get_workspace()
# Auto-prepend memory/ if not present and not absolute path
# Exception: MEMORY.md is in the root directory
if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
path = f'memory/{path}'
file_path = workspace_dir / path
if not file_path.exists():
return ToolResult.fail(f"Error: File not found: {path}")
content = file_path.read_text(encoding='utf-8')
lines = content.split('\n')
# Handle line range
if start_line < 1:
start_line = 1
start_idx = start_line - 1
if num_lines:
end_idx = start_idx + num_lines
selected_lines = lines[start_idx:end_idx]
else:
selected_lines = lines[start_idx:]
result = '\n'.join(selected_lines)
# Add metadata
total_lines = len(lines)
shown_lines = len(selected_lines)
output = [
f"File: {path}",
f"Lines: {start_line}-{start_line + shown_lines - 1} (total: {total_lines})",
"",
result
]
return ToolResult.success('\n'.join(output))
except Exception as e:
return ToolResult.fail(f"Error reading memory file: {str(e)}")

View File

@@ -0,0 +1,102 @@
"""
Memory search tool
Allows agents to search their memory using semantic and keyword search
"""
from typing import Dict, Any, Optional
from agent.tools.base_tool import BaseTool
class MemorySearchTool(BaseTool):
"""Tool for searching agent memory"""
name: str = "memory_search"
description: str = (
"Search agent's long-term memory using semantic and keyword search. "
"Use this to recall past conversations, preferences, and knowledge."
)
params: dict = {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query (can be natural language question or keywords)"
},
"max_results": {
"type": "integer",
"description": "Maximum number of results to return (default: 10)",
"default": 10
},
"min_score": {
"type": "number",
"description": "Minimum relevance score (0-1, default: 0.1)",
"default": 0.1
}
},
"required": ["query"]
}
def __init__(self, memory_manager, user_id: Optional[str] = None):
"""
Initialize memory search tool
Args:
memory_manager: MemoryManager instance
user_id: Optional user ID for scoped search
"""
super().__init__()
self.memory_manager = memory_manager
self.user_id = user_id
def execute(self, args: dict):
"""
Execute memory search
Args:
args: Dictionary with query, max_results, min_score
Returns:
ToolResult with formatted search results
"""
from agent.tools.base_tool import ToolResult
import asyncio
query = args.get("query")
max_results = args.get("max_results", 10)
min_score = args.get("min_score", 0.1)
if not query:
return ToolResult.fail("Error: query parameter is required")
try:
# Run async search in sync context
results = asyncio.run(self.memory_manager.search(
query=query,
user_id=self.user_id,
max_results=max_results,
min_score=min_score,
include_shared=True
))
if not results:
# Return clear message that no memories exist yet
# This prevents infinite retry loops
return ToolResult.success(
f"No memories found for '{query}'. "
f"This is normal if no memories have been stored yet. "
f"You can store new memories by writing to MEMORY.md or memory/YYYY-MM-DD.md files."
)
# Format results
output = [f"Found {len(results)} relevant memories:\n"]
for i, result in enumerate(results, 1):
output.append(f"\n{i}. {result.path} (lines {result.start_line}-{result.end_line})")
output.append(f" Score: {result.score:.3f}")
output.append(f" Snippet: {result.snippet}")
return ToolResult.success("\n".join(output))
except Exception as e:
return ToolResult.fail(f"Error searching memory: {str(e)}")

View File

@@ -0,0 +1,3 @@
from .read import Read
__all__ = ['Read']

557
agent/tools/read/read.py Normal file
View File

@@ -0,0 +1,557 @@
"""
Read tool - Read file contents
Supports text files, images (jpg, png, gif, webp), and PDF files
"""
import os
from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
from common.utils import expand_path
class Read(BaseTool):
"""Tool for reading file contents"""
name: str = "read"
description: str = f"Read or inspect file contents. For text/PDF files, returns content (truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB). For images/videos/audio, returns metadata only (file info, size, type). Use offset/limit for large text files."
params: dict = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the file to read. IMPORTANT: Relative paths are based on workspace directory. To access files outside workspace, use absolute paths starting with ~ or /."
},
"offset": {
"type": "integer",
"description": "Line number to start reading from (1-indexed, optional). Use negative values to read from end (e.g. -20 for last 20 lines)"
},
"limit": {
"type": "integer",
"description": "Maximum number of lines to read (optional)"
}
},
"required": ["path"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
# File type categories
self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.svg', '.ico'}
self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}
self.audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac', '.aac', '.wma'}
self.binary_extensions = {'.exe', '.dll', '.so', '.dylib', '.bin', '.dat', '.db', '.sqlite'}
self.archive_extensions = {'.zip', '.tar', '.gz', '.rar', '.7z', '.bz2', '.xz'}
self.pdf_extensions = {'.pdf'}
self.office_extensions = {'.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx'}
# Readable text formats (will be read with truncation)
self.text_extensions = {
'.txt', '.md', '.markdown', '.rst', '.log', '.csv', '.tsv', '.json', '.xml', '.yaml', '.yml',
'.py', '.js', '.ts', '.java', '.c', '.cpp', '.h', '.hpp', '.go', '.rs', '.rb', '.php',
'.html', '.css', '.scss', '.sass', '.less', '.vue', '.jsx', '.tsx',
'.sh', '.bash', '.zsh', '.fish', '.ps1', '.bat', '.cmd',
'.sql', '.r', '.m', '.swift', '.kt', '.scala', '.clj', '.erl', '.ex',
'.dockerfile', '.makefile', '.cmake', '.gradle', '.properties', '.ini', '.conf', '.cfg',
}
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute file read operation
:param args: Contains file path and optional offset/limit parameters
:return: File content or error message
"""
# Support 'location' as alias for 'path' (LLM may use it from skill listing)
path = args.get("path", "") or args.get("location", "")
path = path.strip() if isinstance(path, str) else ""
offset = args.get("offset")
limit = args.get("limit")
if not path:
return ToolResult.fail("Error: path parameter is required")
# Resolve path
absolute_path = self._resolve_path(path)
# Security check: Prevent reading sensitive config files
env_config_path = expand_path("~/.cow/.env")
if os.path.abspath(absolute_path) == os.path.abspath(env_config_path):
return ToolResult.fail(
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
)
# Check if file exists
if not os.path.exists(absolute_path):
# Provide helpful hint if using relative path
if not os.path.isabs(path) and not path.startswith('~'):
return ToolResult.fail(
f"Error: File not found: {path}\n"
f"Resolved to: {absolute_path}\n"
f"Hint: Relative paths are based on workspace ({self.cwd}). For files outside workspace, use absolute paths."
)
return ToolResult.fail(f"Error: File not found: {path}")
# Check if readable
if not os.access(absolute_path, os.R_OK):
return ToolResult.fail(f"Error: File is not readable: {path}")
# Check file type
file_ext = Path(absolute_path).suffix.lower()
file_size = os.path.getsize(absolute_path)
# Check if image - return metadata for sending
if file_ext in self.image_extensions:
return self._read_image(absolute_path, file_ext)
# Check if video/audio/binary/archive - return metadata only
if file_ext in self.video_extensions:
return self._return_file_metadata(absolute_path, "video", file_size)
if file_ext in self.audio_extensions:
return self._return_file_metadata(absolute_path, "audio", file_size)
if file_ext in self.binary_extensions or file_ext in self.archive_extensions:
return self._return_file_metadata(absolute_path, "binary", file_size)
# Check if PDF
if file_ext in self.pdf_extensions:
return self._read_pdf(absolute_path, path, offset, limit)
# Check if Office document (.docx, .xlsx, .pptx, etc.)
if file_ext in self.office_extensions:
return self._read_office(absolute_path, path, file_ext, offset, limit)
# Read text file (with truncation for large files)
return self._read_text(absolute_path, path, offset, limit)
def _resolve_path(self, path: str) -> str:
"""
Resolve path to absolute path
:param path: Relative or absolute path
:return: Absolute path
"""
# Expand ~ to user home directory
path = expand_path(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))
def _return_file_metadata(self, absolute_path: str, file_type: str, file_size: int) -> ToolResult:
"""
Return file metadata for non-readable files (video, audio, binary, etc.)
:param absolute_path: Absolute path to the file
:param file_type: Type of file (video, audio, binary, etc.)
:param file_size: File size in bytes
:return: File metadata
"""
file_name = Path(absolute_path).name
file_ext = Path(absolute_path).suffix.lower()
# Determine MIME type
mime_types = {
# Video
'.mp4': 'video/mp4', '.avi': 'video/x-msvideo', '.mov': 'video/quicktime',
'.mkv': 'video/x-matroska', '.webm': 'video/webm',
# Audio
'.mp3': 'audio/mpeg', '.wav': 'audio/wav', '.ogg': 'audio/ogg',
'.m4a': 'audio/mp4', '.flac': 'audio/flac',
# Binary
'.zip': 'application/zip', '.tar': 'application/x-tar',
'.gz': 'application/gzip', '.rar': 'application/x-rar-compressed',
}
mime_type = mime_types.get(file_ext, 'application/octet-stream')
result = {
"type": f"{file_type}_metadata",
"file_type": file_type,
"path": absolute_path,
"file_name": file_name,
"mime_type": mime_type,
"size": file_size,
"size_formatted": format_size(file_size),
"message": f"{file_type.capitalize()} 文件: {file_name} ({format_size(file_size)})\n提示: 如果需要发送此文件,请使用 send 工具。"
}
return ToolResult.success(result)
def _read_image(self, absolute_path: str, file_ext: str) -> ToolResult:
"""
Read image file - always return metadata only (images should be sent, not read into context)
:param absolute_path: Absolute path to the image file
:param file_ext: File extension
:return: Result containing image metadata for sending
"""
try:
# Get file size
file_size = os.path.getsize(absolute_path)
# Determine MIME type
mime_type_map = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.gif': 'image/gif',
'.webp': 'image/webp'
}
mime_type = mime_type_map.get(file_ext, 'image/jpeg')
# Return metadata for images (NOT file_to_send - use send tool to actually send)
result = {
"type": "image_metadata",
"file_type": "image",
"path": absolute_path,
"mime_type": mime_type,
"size": file_size,
"size_formatted": format_size(file_size),
"message": f"图片文件: {Path(absolute_path).name} ({format_size(file_size)})\n提示: 如果需要发送此图片,请使用 send 工具。"
}
return ToolResult.success(result)
except Exception as e:
return ToolResult.fail(f"Error reading image file: {str(e)}")
def _read_text(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
"""
Read text file
:param absolute_path: Absolute path to the file
:param display_path: Path to display
:param offset: Starting line number (1-indexed)
:param limit: Maximum number of lines to read
:return: File content or error message
"""
try:
# Check file size first
file_size = os.path.getsize(absolute_path)
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
if file_size > MAX_FILE_SIZE:
# File too large, return metadata only
return ToolResult.success({
"type": "file_to_send",
"file_type": "document",
"path": absolute_path,
"size": file_size,
"size_formatted": format_size(file_size),
"message": f"文件过大 ({format_size(file_size)} > 50MB),无法读取内容。文件路径: {absolute_path}"
})
# Read file (utf-8-sig strips BOM automatically on Windows)
with open(absolute_path, 'r', encoding='utf-8-sig') as f:
content = f.read()
# Truncate content if too long (20K characters max for model context)
MAX_CONTENT_CHARS = 20 * 1024 # 20K characters
content_truncated = False
if len(content) > MAX_CONTENT_CHARS:
content = content[:MAX_CONTENT_CHARS]
content_truncated = True
all_lines = content.split('\n')
total_file_lines = len(all_lines)
# Apply offset (if specified)
start_line = 0
if offset is not None:
if offset < 0:
# Negative offset: read from end
# -20 means "last 20 lines" → start from (total - 20)
start_line = max(0, total_file_lines + offset)
else:
# Positive offset: read from start (1-indexed)
start_line = max(0, offset - 1) # Convert to 0-indexed
if start_line >= total_file_lines:
return ToolResult.fail(
f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)"
)
start_line_display = start_line + 1 # For display (1-indexed)
# If user specified limit, use it
selected_content = content
user_limited_lines = None
if limit is not None:
end_line = min(start_line + limit, total_file_lines)
selected_content = '\n'.join(all_lines[start_line:end_line])
user_limited_lines = end_line - start_line
elif offset is not None:
selected_content = '\n'.join(all_lines[start_line:])
# Apply truncation (considering line count and byte limits)
truncation = truncate_head(selected_content)
output_text = ""
details = {}
# Add truncation warning if content was truncated
if content_truncated:
output_text = f"[文件内容已截断到前 {format_size(MAX_CONTENT_CHARS)},完整文件大小: {format_size(file_size)}]\n\n"
if truncation.first_line_exceeds_limit:
# First line exceeds 30KB limit
first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
output_text = f"[Line {start_line_display} is {first_line_size}, exceeds {format_size(DEFAULT_MAX_BYTES)} limit. Use bash tool to read: head -c {DEFAULT_MAX_BYTES} {display_path} | tail -n +{start_line_display}]"
details["truncation"] = truncation.to_dict()
elif truncation.truncated:
# Truncation occurred
end_line_display = start_line_display + truncation.output_lines - 1
next_offset = end_line_display + 1
output_text = truncation.content
if truncation.truncated_by == "lines":
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines}. Use offset={next_offset} to continue.]"
else:
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"
details["truncation"] = truncation.to_dict()
elif user_limited_lines is not None and start_line + user_limited_lines < total_file_lines:
# User specified limit, more content available, but no truncation
remaining = total_file_lines - (start_line + user_limited_lines)
next_offset = start_line + user_limited_lines + 1
output_text = truncation.content
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
else:
# No truncation, no exceeding user limit
output_text = truncation.content
result = {
"content": output_text,
"total_lines": total_file_lines,
"start_line": start_line_display,
"output_lines": truncation.output_lines
}
if details:
result["details"] = details
return ToolResult.success(result)
except UnicodeDecodeError:
return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {display_path}")
except Exception as e:
return ToolResult.fail(f"Error reading file: {str(e)}")
def _read_office(self, absolute_path: str, display_path: str, file_ext: str,
offset: int = None, limit: int = None) -> ToolResult:
"""Read Office documents (.docx, .xlsx, .pptx) using python-docx / openpyxl / python-pptx."""
try:
text = self._extract_office_text(absolute_path, file_ext)
except ImportError as e:
return ToolResult.fail(str(e))
except Exception as e:
return ToolResult.fail(f"Error reading Office document: {e}")
if not text or not text.strip():
return ToolResult.success({
"content": f"[Office file {Path(absolute_path).name}: no text content could be extracted]",
})
all_lines = text.split('\n')
total_lines = len(all_lines)
start_line = 0
if offset is not None:
if offset < 0:
start_line = max(0, total_lines + offset)
else:
start_line = max(0, offset - 1)
if start_line >= total_lines:
return ToolResult.fail(
f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
)
selected_content = text
user_limited_lines = None
if limit is not None:
end_line = min(start_line + limit, total_lines)
selected_content = '\n'.join(all_lines[start_line:end_line])
user_limited_lines = end_line - start_line
elif offset is not None:
selected_content = '\n'.join(all_lines[start_line:])
truncation = truncate_head(selected_content)
start_line_display = start_line + 1
output_text = ""
if truncation.truncated:
end_line_display = start_line_display + truncation.output_lines - 1
next_offset = end_line_display + 1
output_text = truncation.content
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
remaining = total_lines - (start_line + user_limited_lines)
next_offset = start_line + user_limited_lines + 1
output_text = truncation.content
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
else:
output_text = truncation.content
return ToolResult.success({
"content": output_text,
"total_lines": total_lines,
"start_line": start_line_display,
"output_lines": truncation.output_lines,
})
@staticmethod
def _extract_office_text(absolute_path: str, file_ext: str) -> str:
"""Extract plain text from an Office document."""
if file_ext in ('.docx', '.doc'):
try:
from docx import Document
except ImportError:
raise ImportError("Error: python-docx library not installed. Install with: pip install python-docx")
doc = Document(absolute_path)
paragraphs = [p.text for p in doc.paragraphs]
for table in doc.tables:
for row in table.rows:
paragraphs.append('\t'.join(cell.text for cell in row.cells))
return '\n'.join(paragraphs)
if file_ext in ('.xlsx', '.xls'):
try:
from openpyxl import load_workbook
except ImportError:
raise ImportError("Error: openpyxl library not installed. Install with: pip install openpyxl")
wb = load_workbook(absolute_path, read_only=True, data_only=True)
parts = []
for ws in wb.worksheets:
parts.append(f"--- Sheet: {ws.title} ---")
for row in ws.iter_rows(values_only=True):
parts.append('\t'.join(str(c) if c is not None else '' for c in row))
wb.close()
return '\n'.join(parts)
if file_ext in ('.pptx', '.ppt'):
try:
from pptx import Presentation
except ImportError:
raise ImportError("Error: python-pptx library not installed. Install with: pip install python-pptx")
prs = Presentation(absolute_path)
parts = []
for i, slide in enumerate(prs.slides, 1):
parts.append(f"--- Slide {i} ---")
for shape in slide.shapes:
if shape.has_text_frame:
for para in shape.text_frame.paragraphs:
text = para.text.strip()
if text:
parts.append(text)
return '\n'.join(parts)
return ""
def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
"""
Read PDF file content
:param absolute_path: Absolute path to the file
:param display_path: Path to display
:param offset: Starting line number (1-indexed)
:param limit: Maximum number of lines to read
:return: PDF text content or error message
"""
try:
# Try to import pypdf
try:
from pypdf import PdfReader
except ImportError:
return ToolResult.fail(
"Error: pypdf library not installed. Install with: pip install pypdf"
)
# Read PDF
reader = PdfReader(absolute_path)
total_pages = len(reader.pages)
# Extract text from all pages
text_parts = []
for page_num, page in enumerate(reader.pages, 1):
page_text = page.extract_text()
if page_text.strip():
text_parts.append(f"--- Page {page_num} ---\n{page_text}")
if not text_parts:
return ToolResult.success({
"content": f"[PDF file with {total_pages} pages, but no text content could be extracted]",
"total_pages": total_pages,
"message": "PDF may contain only images or be encrypted"
})
# Merge all text
full_content = "\n\n".join(text_parts)
all_lines = full_content.split('\n')
total_lines = len(all_lines)
# Apply offset and limit (same logic as text files)
start_line = 0
if offset is not None:
start_line = max(0, offset - 1)
if start_line >= total_lines:
return ToolResult.fail(
f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
)
start_line_display = start_line + 1
selected_content = full_content
user_limited_lines = None
if limit is not None:
end_line = min(start_line + limit, total_lines)
selected_content = '\n'.join(all_lines[start_line:end_line])
user_limited_lines = end_line - start_line
elif offset is not None:
selected_content = '\n'.join(all_lines[start_line:])
# Apply truncation
truncation = truncate_head(selected_content)
output_text = ""
details = {}
if truncation.truncated:
end_line_display = start_line_display + truncation.output_lines - 1
next_offset = end_line_display + 1
output_text = truncation.content
if truncation.truncated_by == "lines":
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
else:
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"
details["truncation"] = truncation.to_dict()
elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
remaining = total_lines - (start_line + user_limited_lines)
next_offset = start_line + user_limited_lines + 1
output_text = truncation.content
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
else:
output_text = truncation.content
result = {
"content": output_text,
"total_pages": total_pages,
"total_lines": total_lines,
"start_line": start_line_display,
"output_lines": truncation.output_lines
}
if details:
result["details"] = details
return ToolResult.success(result)
except Exception as e:
return ToolResult.fail(f"Error reading PDF file: {str(e)}")

View File

@@ -0,0 +1,287 @@
# 定时任务工具 (Scheduler Tool)
## 功能简介
定时任务工具允许 Agent 创建、管理和执行定时任务,支持:
-**定时提醒**: 在指定时间发送消息
- 🔄 **周期性任务**: 按固定间隔或 cron 表达式重复执行
- 🔧 **动态工具调用**: 定时执行其他工具并发送结果(如搜索新闻、查询天气等)
- 📋 **任务管理**: 查询、启用、禁用、删除任务
## 安装依赖
```bash
pip install croniter>=2.0.0
```
## 使用方法
### 1. 创建定时任务
Agent 可以通过自然语言创建定时任务,支持两种类型:
#### 1.1 静态消息任务
发送预定义的消息:
**示例对话:**
```
用户: 每天早上9点提醒我开会
Agent: [调用 scheduler 工具]
action: create
name: 每日开会提醒
message: 该开会了!
schedule_type: cron
schedule_value: 0 9 * * *
```
#### 1.2 动态工具调用任务
定时执行工具并发送结果:
**示例对话:**
```
用户: 每天早上8点帮我读取一下今日日程
Agent: [调用 scheduler 工具]
action: create
name: 每日日程
tool_call:
tool_name: read
tool_params:
file_path: ~/cow/schedule.txt
result_prefix: 📅 今日日程
schedule_type: cron
schedule_value: 0 8 * * *
```
**工具调用参数说明:**
- `tool_name`: 要调用的工具名称(如 `bash``read``write` 等内置工具)
- `tool_params`: 工具的参数(字典格式)
- `result_prefix`: 可选,在结果前添加的前缀文本
**注意:** 如果要使用 skills如 bocha-search需要通过 `bash` 工具调用 skill 脚本
### 2. 支持的调度类型
#### Cron 表达式 (`cron`)
使用标准 cron 表达式:
```
0 9 * * * # 每天 9:00
0 */2 * * * # 每 2 小时
30 8 * * 1-5 # 工作日 8:30
0 0 1 * * # 每月 1 号
```
#### 固定间隔 (`interval`)
以秒为单位的间隔:
```
3600 # 每小时
86400 # 每天
1800 # 每 30 分钟
```
#### 一次性任务 (`once`)
指定具体时间ISO 格式):
```
2024-12-25T09:00:00
2024-12-31T23:59:59
```
### 3. 查询任务列表
```
用户: 查看我的定时任务
Agent: [调用 scheduler 工具]
action: list
```
### 4. 查看任务详情
```
用户: 查看任务 abc123 的详情
Agent: [调用 scheduler 工具]
action: get
task_id: abc123
```
### 5. 删除任务
```
用户: 删除任务 abc123
Agent: [调用 scheduler 工具]
action: delete
task_id: abc123
```
### 6. 启用/禁用任务
```
用户: 暂停任务 abc123
Agent: [调用 scheduler 工具]
action: disable
task_id: abc123
用户: 恢复任务 abc123
Agent: [调用 scheduler 工具]
action: enable
task_id: abc123
```
## 任务存储
任务保存在 JSON 文件中:
```
~/cow/scheduler/tasks.json
```
任务数据结构:
**静态消息任务:**
```json
{
"id": "abc123",
"name": "每日提醒",
"enabled": true,
"created_at": "2024-01-01T10:00:00",
"updated_at": "2024-01-01T10:00:00",
"schedule": {
"type": "cron",
"expression": "0 9 * * *"
},
"action": {
"type": "send_message",
"content": "该开会了!",
"receiver": "wxid_xxx",
"receiver_name": "张三",
"is_group": false,
"channel_type": "wechat"
},
"next_run_at": "2024-01-02T09:00:00",
"last_run_at": "2024-01-01T09:00:00"
}
```
**动态工具调用任务:**
```json
{
"id": "def456",
"name": "每日日程",
"enabled": true,
"created_at": "2024-01-01T10:00:00",
"updated_at": "2024-01-01T10:00:00",
"schedule": {
"type": "cron",
"expression": "0 8 * * *"
},
"action": {
"type": "tool_call",
"tool_name": "read",
"tool_params": {
"file_path": "~/cow/schedule.txt"
},
"result_prefix": "📅 今日日程",
"receiver": "wxid_xxx",
"receiver_name": "张三",
"is_group": false,
"channel_type": "wechat"
},
"next_run_at": "2024-01-02T08:00:00"
}
```
## 后台服务
定时任务由后台服务 `SchedulerService` 管理:
- 每 30 秒检查一次到期任务
- 自动执行到期任务
- 计算下次执行时间
- 记录执行历史和错误
服务在 Agent 初始化时自动启动,无需手动配置。
## 接收者确定
定时任务会发送给**创建任务时的对话对象**
- 如果在私聊中创建,发送给该用户
- 如果在群聊中创建,发送到该群
- 接收者信息在创建时自动保存
## 常见用例
### 1. 每日提醒(静态消息)
```
用户: 每天早上8点提醒我吃药
Agent: ✅ 定时任务创建成功
任务ID: a1b2c3d4
调度: 每天 8:00
消息: 该吃药了!
```
### 2. 工作日提醒(静态消息)
```
用户: 工作日下午6点提醒我下班
Agent: [创建 cron: 0 18 * * 1-5]
消息: 该下班了!
```
### 3. 倒计时提醒(静态消息)
```
用户: 1小时后提醒我
Agent: [创建 interval: 3600]
```
### 4. 每日日程推送(动态工具调用)
```
用户: 每天早上8点帮我读取今日日程
Agent: ✅ 定时任务创建成功
任务ID: schedule001
调度: 每天 8:00
工具: read(file_path='~/cow/schedule.txt')
前缀: 📅 今日日程
```
### 5. 定时文件备份(动态工具调用)
```
用户: 每天晚上11点备份工作文件
Agent: [创建 cron: 0 23 * * *]
工具: bash(command='cp ~/cow/work.txt ~/cow/backup/work_$(date +%Y%m%d).txt')
前缀: ✅ 文件已备份
```
### 6. 周报提醒(静态消息)
```
用户: 每周五下午5点提醒我写周报
Agent: [创建 cron: 0 17 * * 5]
消息: 📊 该写周报了!
```
### 4. 特定日期提醒
```
用户: 12月25日早上9点提醒我圣诞快乐
Agent: [创建 once: 2024-12-25T09:00:00]
```
## 注意事项
1. **时区**: 使用系统本地时区
2. **精度**: 检查间隔为 30 秒,实际执行可能有 ±30 秒误差
3. **持久化**: 任务保存在文件中,重启后自动恢复
4. **一次性任务**: 执行后自动禁用,不会删除(可手动删除)
5. **错误处理**: 执行失败会记录错误,不影响其他任务
## 技术实现
- **TaskStore**: 任务持久化存储
- **SchedulerService**: 后台调度服务
- **SchedulerTool**: Agent 工具接口
- **Integration**: 与 AgentBridge 集成
## 依赖
- `croniter`: Cron 表达式解析(轻量级,仅 ~50KB

View File

@@ -0,0 +1,7 @@
"""
Scheduler tool for managing scheduled tasks
"""
from .scheduler_tool import SchedulerTool
__all__ = ["SchedulerTool"]

View File

@@ -0,0 +1,464 @@
"""
Integration module for scheduler with AgentBridge
"""
import os
from typing import Optional
from config import conf
from common.log import logger
from common.utils import expand_path
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
# Global scheduler service instance
_scheduler_service = None
_task_store = None
def init_scheduler(agent_bridge) -> bool:
"""
Initialize scheduler service
Args:
agent_bridge: AgentBridge instance
Returns:
True if initialized successfully
"""
global _scheduler_service, _task_store
try:
from agent.tools.scheduler.task_store import TaskStore
from agent.tools.scheduler.scheduler_service import SchedulerService
# Get workspace from config
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
store_path = os.path.join(workspace_root, "scheduler", "tasks.json")
# Create task store
_task_store = TaskStore(store_path)
logger.debug(f"[Scheduler] Task store initialized: {store_path}")
# Create execute callback
def execute_task_callback(task: dict):
"""Callback to execute a scheduled task"""
try:
action = task.get("action", {})
action_type = action.get("type")
if action_type == "agent_task":
_execute_agent_task(task, agent_bridge)
elif action_type == "send_message":
# Legacy support for old tasks
_execute_send_message(task, agent_bridge)
elif action_type == "tool_call":
# Legacy support for old tasks
_execute_tool_call(task, agent_bridge)
elif action_type == "skill_call":
# Legacy support for old tasks
_execute_skill_call(task, agent_bridge)
else:
logger.warning(f"[Scheduler] Unknown action type: {action_type}")
except Exception as e:
logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}")
# Create scheduler service
_scheduler_service = SchedulerService(_task_store, execute_task_callback)
_scheduler_service.start()
logger.debug("[Scheduler] Scheduler service initialized and started")
return True
except Exception as e:
logger.error(f"[Scheduler] Failed to initialize scheduler: {e}")
return False
def get_task_store():
"""Get the global task store instance"""
return _task_store
def get_scheduler_service():
"""Get the global scheduler service instance"""
return _scheduler_service
def _execute_agent_task(task: dict, agent_bridge):
"""
Execute an agent_task action - let Agent handle the task
Args:
task: Task dictionary
agent_bridge: AgentBridge instance
"""
try:
action = task.get("action", {})
task_description = action.get("task_description")
receiver = action.get("receiver")
is_group = action.get("is_group", False)
channel_type = action.get("channel_type", "unknown")
if not task_description:
logger.error(f"[Scheduler] Task {task['id']}: No task_description specified")
return
if not receiver:
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
return
# Check for unsupported channels
if channel_type == "dingtalk":
logger.warning(f"[Scheduler] Task {task['id']}: DingTalk channel does not support scheduled messages (Stream mode limitation). Task will execute but message cannot be sent.")
logger.info(f"[Scheduler] Task {task['id']}: Executing agent task '{task_description}'")
# Create a unique session_id for this scheduled task to avoid polluting user's conversation
# Format: scheduler_<receiver>_<task_id> to ensure isolation
scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
# Create context for Agent
context = Context(ContextType.TEXT, task_description)
context["receiver"] = receiver
context["isgroup"] = is_group
context["session_id"] = scheduler_session_id
# Channel-specific setup
if channel_type == "web":
import uuid
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
context["request_id"] = request_id
elif channel_type == "feishu":
context["receive_id_type"] = "chat_id" if is_group else "open_id"
context["msg"] = None
elif channel_type == "dingtalk":
# DingTalk requires msg object, set to None for scheduled tasks
context["msg"] = None
if not is_group:
sender_staff_id = action.get("dingtalk_sender_staff_id")
if sender_staff_id:
context["dingtalk_sender_staff_id"] = sender_staff_id
elif channel_type == "wecom_bot":
context["msg"] = None
# Use Agent to execute the task
# Mark this as a scheduled task execution to prevent recursive task creation
context["is_scheduled_task"] = True
try:
# Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations
reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=False)
if reply and reply.content:
# Send the reply via channel
from channel.channel_factory import create_channel
try:
channel = create_channel(channel_type)
if channel:
# For web channel, register request_id
if channel_type == "web" and hasattr(channel, 'request_to_session'):
request_id = context.get("request_id")
if request_id:
channel.request_to_session[request_id] = receiver
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
# Send the reply
channel.send(reply, context)
logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}")
else:
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
except Exception as e:
logger.error(f"[Scheduler] Failed to send result: {e}")
else:
logger.error(f"[Scheduler] Task {task['id']}: No result from agent execution")
except Exception as e:
logger.error(f"[Scheduler] Failed to execute task via Agent: {e}")
import traceback
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
except Exception as e:
logger.error(f"[Scheduler] Error in _execute_agent_task: {e}")
import traceback
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
def _execute_send_message(task: dict, agent_bridge):
"""
Execute a send_message action
Args:
task: Task dictionary
agent_bridge: AgentBridge instance
"""
try:
action = task.get("action", {})
content = action.get("content", "")
receiver = action.get("receiver")
is_group = action.get("is_group", False)
channel_type = action.get("channel_type", "unknown")
if not receiver:
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
return
# Create context for sending message
context = Context(ContextType.TEXT, content)
context["receiver"] = receiver
context["isgroup"] = is_group
context["session_id"] = receiver
# Channel-specific context setup
if channel_type == "web":
# Web channel needs request_id
import uuid
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
context["request_id"] = request_id
logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
elif channel_type == "feishu":
# Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
# Use chat_id for groups, open_id for private chats
context["receive_id_type"] = "chat_id" if is_group else "open_id"
# Keep isgroup as is, but set msg to None (no original message to reply to)
# Feishu channel will detect this and send as new message instead of reply
context["msg"] = None
logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
elif channel_type == "dingtalk":
# DingTalk channel setup
context["msg"] = None
# 如果是单聊,需要传递 sender_staff_id
if not is_group:
sender_staff_id = action.get("dingtalk_sender_staff_id")
if sender_staff_id:
context["dingtalk_sender_staff_id"] = sender_staff_id
logger.debug(f"[Scheduler] DingTalk single chat: sender_staff_id={sender_staff_id}")
else:
logger.warning(f"[Scheduler] Task {task['id']}: DingTalk single chat message missing sender_staff_id")
elif channel_type == "wecom_bot":
context["msg"] = None
elif channel_type == "qq":
context["msg"] = None
# Create reply
reply = Reply(ReplyType.TEXT, content)
# Get channel and send
from channel.channel_factory import create_channel
try:
channel = create_channel(channel_type)
if channel:
# For web channel, register the request_id to session mapping
if channel_type == "web" and hasattr(channel, 'request_to_session'):
channel.request_to_session[request_id] = receiver
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
channel.send(reply, context)
logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}")
else:
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
except Exception as e:
logger.error(f"[Scheduler] Failed to send message: {e}")
import traceback
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
except Exception as e:
logger.error(f"[Scheduler] Error in _execute_send_message: {e}")
import traceback
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
def _execute_tool_call(task: dict, agent_bridge):
"""
Execute a tool_call action
Args:
task: Task dictionary
agent_bridge: AgentBridge instance
"""
try:
action = task.get("action", {})
# Support both old and new field names
tool_name = action.get("call_name") or action.get("tool_name")
tool_params = action.get("call_params") or action.get("tool_params", {})
result_prefix = action.get("result_prefix", "")
receiver = action.get("receiver")
is_group = action.get("is_group", False)
channel_type = action.get("channel_type", "unknown")
if not tool_name:
logger.error(f"[Scheduler] Task {task['id']}: No tool_name specified")
return
if not receiver:
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
return
# Get tool manager and create tool instance
from agent.tools.tool_manager import ToolManager
tool_manager = ToolManager()
tool = tool_manager.create_tool(tool_name)
if not tool:
logger.error(f"[Scheduler] Task {task['id']}: Tool '{tool_name}' not found")
return
# Execute tool
logger.info(f"[Scheduler] Task {task['id']}: Executing tool '{tool_name}' with params {tool_params}")
result = tool.execute(tool_params)
# Get result content
if hasattr(result, 'result'):
content = result.result
else:
content = str(result)
# Add prefix if specified
if result_prefix:
content = f"{result_prefix}\n\n{content}"
# Send result as message
context = Context(ContextType.TEXT, content)
context["receiver"] = receiver
context["isgroup"] = is_group
context["session_id"] = receiver
# Channel-specific context setup
if channel_type == "web":
# Web channel needs request_id
import uuid
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
context["request_id"] = request_id
logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
elif channel_type == "feishu":
context["receive_id_type"] = "chat_id" if is_group else "open_id"
context["msg"] = None
logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
elif channel_type == "wecom_bot":
context["msg"] = None
reply = Reply(ReplyType.TEXT, content)
# Get channel and send
from channel.channel_factory import create_channel
try:
channel = create_channel(channel_type)
if channel:
if channel_type == "web" and hasattr(channel, 'request_to_session'):
channel.request_to_session[request_id] = receiver
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
channel.send(reply, context)
logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
else:
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
except Exception as e:
logger.error(f"[Scheduler] Failed to send tool result: {e}")
except Exception as e:
logger.error(f"[Scheduler] Error in _execute_tool_call: {e}")
def _execute_skill_call(task: dict, agent_bridge):
"""
Execute a skill_call action by asking Agent to run the skill
Args:
task: Task dictionary
agent_bridge: AgentBridge instance
"""
try:
action = task.get("action", {})
# Support both old and new field names
skill_name = action.get("call_name") or action.get("skill_name")
skill_params = action.get("call_params") or action.get("skill_params", {})
result_prefix = action.get("result_prefix", "")
receiver = action.get("receiver")
is_group = action.get("isgroup", False)
channel_type = action.get("channel_type", "unknown")
if not skill_name:
logger.error(f"[Scheduler] Task {task['id']}: No skill_name specified")
return
if not receiver:
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
return
logger.info(f"[Scheduler] Task {task['id']}: Executing skill '{skill_name}' with params {skill_params}")
# Create a unique session_id for this scheduled task to avoid polluting user's conversation
# Format: scheduler_<receiver>_<task_id> to ensure isolation
scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
# Build a natural language query for the Agent to execute the skill
# Format: "Use skill-name to do something with params"
param_str = ", ".join([f"{k}={v}" for k, v in skill_params.items()])
query = f"Use {skill_name} skill"
if param_str:
query += f" with {param_str}"
# Create context for Agent
context = Context(ContextType.TEXT, query)
context["receiver"] = receiver
context["isgroup"] = is_group
context["session_id"] = scheduler_session_id
# Channel-specific setup
if channel_type == "web":
import uuid
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
context["request_id"] = request_id
elif channel_type == "feishu":
context["receive_id_type"] = "chat_id" if is_group else "open_id"
context["msg"] = None
elif channel_type == "wecom_bot":
context["msg"] = None
# Use Agent to execute the skill
try:
# Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations
reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=False)
if reply and reply.content:
content = reply.content
# Add prefix if specified
if result_prefix:
content = f"{result_prefix}\n\n{content}"
logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}")
else:
logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution")
except Exception as e:
logger.error(f"[Scheduler] Failed to execute skill via Agent: {e}")
import traceback
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
except Exception as e:
logger.error(f"[Scheduler] Error in _execute_skill_call: {e}")
import traceback
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
def attach_scheduler_to_tool(tool, context: Context = None):
"""
Attach scheduler components to a SchedulerTool instance
Args:
tool: SchedulerTool instance
context: Current context (optional)
"""
if _task_store:
tool.task_store = _task_store
if context:
tool.current_context = context
channel_type = context.get("channel_type") or conf().get("channel_type", "unknown")
if not tool.config:
tool.config = {}
tool.config["channel_type"] = channel_type

View File

@@ -0,0 +1,213 @@
"""
Background scheduler service for executing scheduled tasks
"""
import time
import threading
from datetime import datetime, timedelta
from typing import Callable, Optional
from croniter import croniter
from common.log import logger
class SchedulerService:
"""
Background service that executes scheduled tasks
"""
def __init__(self, task_store, execute_callback: Callable):
"""
Initialize scheduler service
Args:
task_store: TaskStore instance
execute_callback: Function to call when executing a task
"""
self.task_store = task_store
self.execute_callback = execute_callback
self.running = False
self.thread = None
self._lock = threading.Lock()
def start(self):
"""Start the scheduler service"""
with self._lock:
if self.running:
logger.warning("[Scheduler] Service already running")
return
self.running = True
self.thread = threading.Thread(target=self._run_loop, daemon=True)
self.thread.start()
logger.debug("[Scheduler] Service started")
def stop(self):
"""Stop the scheduler service"""
with self._lock:
if not self.running:
return
self.running = False
if self.thread:
self.thread.join(timeout=5)
logger.info("[Scheduler] Service stopped")
def _run_loop(self):
"""Main scheduler loop"""
logger.debug("[Scheduler] Scheduler loop started")
while self.running:
try:
self._check_and_execute_tasks()
except Exception as e:
logger.error(f"[Scheduler] Error in scheduler loop: {e}")
time.sleep(30)
def _check_and_execute_tasks(self):
"""Check for due tasks and execute them"""
now = datetime.now()
tasks = self.task_store.list_tasks(enabled_only=True)
for task in tasks:
try:
# Check if task is due
if self._is_task_due(task, now):
logger.info(f"[Scheduler] Executing task: {task['id']} - {task['name']}")
self._execute_task(task)
# Update next run time
next_run = self._calculate_next_run(task, now)
if next_run:
self.task_store.update_task(task['id'], {
"next_run_at": next_run.isoformat(),
"last_run_at": now.isoformat()
})
else:
# One-time task completed, remove it
self.task_store.delete_task(task['id'])
logger.info(f"[Scheduler] One-time task completed and removed: {task['id']}")
except Exception as e:
logger.error(f"[Scheduler] Error processing task {task.get('id')}: {e}")
def _is_task_due(self, task: dict, now: datetime) -> bool:
"""
Check if a task is due to run
Args:
task: Task dictionary
now: Current datetime
Returns:
True if task should run now
"""
next_run_str = task.get("next_run_at")
if not next_run_str:
# Calculate initial next_run_at
next_run = self._calculate_next_run(task, now)
if next_run:
self.task_store.update_task(task['id'], {
"next_run_at": next_run.isoformat()
})
return False
return False
try:
next_run = datetime.fromisoformat(next_run_str)
# Check if task is overdue (e.g., service restart)
if next_run < now:
time_diff = (now - next_run).total_seconds()
# If overdue by more than 5 minutes, skip this run and schedule next
if time_diff > 300: # 5 minutes
logger.warning(f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, skipping and scheduling next run")
# For one-time tasks, remove them directly
schedule = task.get("schedule", {})
if schedule.get("type") == "once":
self.task_store.delete_task(task['id'])
logger.info(f"[Scheduler] One-time task {task['id']} expired, removed")
return False
# For recurring tasks, calculate next run from now
next_next_run = self._calculate_next_run(task, now)
if next_next_run:
self.task_store.update_task(task['id'], {
"next_run_at": next_next_run.isoformat()
})
logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}")
return False
return now >= next_run
except Exception:
return False
def _calculate_next_run(self, task: dict, from_time: datetime) -> Optional[datetime]:
"""
Calculate next run time for a task
Args:
task: Task dictionary
from_time: Calculate from this time
Returns:
Next run datetime or None for one-time tasks
"""
schedule = task.get("schedule", {})
schedule_type = schedule.get("type")
if schedule_type == "cron":
# Cron expression
expression = schedule.get("expression")
if not expression:
return None
try:
cron = croniter(expression, from_time)
return cron.get_next(datetime)
except Exception as e:
logger.error(f"[Scheduler] Invalid cron expression '{expression}': {e}")
return None
elif schedule_type == "interval":
# Interval in seconds
seconds = schedule.get("seconds", 0)
if seconds <= 0:
return None
return from_time + timedelta(seconds=seconds)
elif schedule_type == "once":
# One-time task at specific time
run_at_str = schedule.get("run_at")
if not run_at_str:
return None
try:
run_at = datetime.fromisoformat(run_at_str)
# Only return if in the future
if run_at > from_time:
return run_at
except Exception:
pass
return None
return None
def _execute_task(self, task: dict):
"""
Execute a task
Args:
task: Task dictionary
"""
try:
# Call the execute callback
self.execute_callback(task)
except Exception as e:
logger.error(f"[Scheduler] Error executing task {task['id']}: {e}")
# Update task with error
self.task_store.update_task(task['id'], {
"last_error": str(e),
"last_error_at": datetime.now().isoformat()
})

View File

@@ -0,0 +1,443 @@
"""
Scheduler tool for creating and managing scheduled tasks
"""
import uuid
from datetime import datetime
from typing import Any, Dict, Optional
from croniter import croniter
from agent.tools.base_tool import BaseTool, ToolResult
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
class SchedulerTool(BaseTool):
"""
Tool for managing scheduled tasks (reminders, notifications, etc.)
"""
name: str = "scheduler"
description: str = (
"创建、查询和管理定时任务(提醒、周期性任务等)。\n\n"
"⚠️ 重要:仅当需要「定时/提醒/每天/每周/X分钟后/X点」等延迟或周期执行时才使用此工具。"
"使用方法:\n"
"- 创建action='create', name='任务名', message/ai_task='内容', schedule_type='once/interval/cron', schedule_value='...'\n"
"- 查询action='list' / action='get', task_id='任务ID'\n"
"- 管理action='delete/enable/disable', task_id='任务ID'\n\n"
"调度类型:\n"
"- once: 一次性任务,支持相对时间(+5s,+10m,+1h,+1d)或ISO时间\n"
"- interval: 固定间隔(秒)如3600表示每小时\n"
"- cron: cron表达式'0 8 * * *'表示每天8点\n\n"
"注意:'X秒后'用once+相对时间,'每X秒'用interval"
)
params: dict = {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["create", "list", "get", "delete", "enable", "disable"],
"description": "操作类型: create(创建), list(列表), get(查询), delete(删除), enable(启用), disable(禁用)"
},
"task_id": {
"type": "string",
"description": "任务ID (用于 get/delete/enable/disable 操作)"
},
"name": {
"type": "string",
"description": "任务名称 (用于 create 操作)"
},
"message": {
"type": "string",
"description": "固定消息内容 (与ai_task二选一)"
},
"ai_task": {
"type": "string",
"description": "AI任务描述 (与message二选一)用于定时让AI执行的任务"
},
"schedule_type": {
"type": "string",
"enum": ["cron", "interval", "once"],
"description": "调度类型 (用于 create 操作): cron(cron表达式), interval(固定间隔秒数), once(一次性)"
},
"schedule_value": {
"type": "string",
"description": "调度值: cron表达式/间隔秒数/时间(+5s,+10m,+1h或ISO格式)"
}
},
"required": ["action"]
}
def __init__(self, config: dict = None):
super().__init__()
self.config = config or {}
# Will be set by agent bridge
self.task_store = None
self.current_context = None
def execute(self, params: dict) -> ToolResult:
"""
Execute scheduler operations
Args:
params: Dictionary containing:
- action: Operation type (create/list/get/delete/enable/disable)
- Other parameters depending on action
Returns:
ToolResult object
"""
# Extract parameters
action = params.get("action")
kwargs = params
if not self.task_store:
return ToolResult.fail("错误: 定时任务系统未初始化")
try:
if action == "create":
result = self._create_task(**kwargs)
return ToolResult.success(result)
elif action == "list":
result = self._list_tasks(**kwargs)
return ToolResult.success(result)
elif action == "get":
result = self._get_task(**kwargs)
return ToolResult.success(result)
elif action == "delete":
result = self._delete_task(**kwargs)
return ToolResult.success(result)
elif action == "enable":
result = self._enable_task(**kwargs)
return ToolResult.success(result)
elif action == "disable":
result = self._disable_task(**kwargs)
return ToolResult.success(result)
else:
return ToolResult.fail(f"未知操作: {action}")
except Exception as e:
logger.error(f"[SchedulerTool] Error: {e}")
return ToolResult.fail(f"操作失败: {str(e)}")
def _create_task(self, **kwargs) -> str:
"""Create a new scheduled task"""
name = kwargs.get("name")
message = kwargs.get("message")
ai_task = kwargs.get("ai_task")
schedule_type = kwargs.get("schedule_type")
schedule_value = kwargs.get("schedule_value")
# Validate required fields
if not name:
return "错误: 缺少任务名称 (name)"
# Check that exactly one of message/ai_task is provided
if not message and not ai_task:
return "错误: 必须提供 message固定消息或 ai_taskAI任务之一"
if message and ai_task:
return "错误: message 和 ai_task 只能提供其中一个"
if not schedule_type:
return "错误: 缺少调度类型 (schedule_type)"
if not schedule_value:
return "错误: 缺少调度值 (schedule_value)"
# Validate schedule
schedule = self._parse_schedule(schedule_type, schedule_value)
if not schedule:
return f"错误: 无效的调度配置 - type: {schedule_type}, value: {schedule_value}"
# Get context info for receiver
if not self.current_context:
return "错误: 无法获取当前对话上下文"
context = self.current_context
# Create task
task_id = str(uuid.uuid4())[:8]
# Build action based on message or ai_task
if message:
action = {
"type": "send_message",
"content": message,
"receiver": context.get("receiver"),
"receiver_name": self._get_receiver_name(context),
"is_group": context.get("isgroup", False),
"channel_type": self.config.get("channel_type", "unknown")
}
else: # ai_task
action = {
"type": "agent_task",
"task_description": ai_task,
"receiver": context.get("receiver"),
"receiver_name": self._get_receiver_name(context),
"is_group": context.get("isgroup", False),
"channel_type": self.config.get("channel_type", "unknown")
}
# 针对钉钉单聊,额外存储 sender_staff_id
msg = context.kwargs.get("msg")
if msg and hasattr(msg, 'sender_staff_id') and not context.get("isgroup", False):
action["dingtalk_sender_staff_id"] = msg.sender_staff_id
task_data = {
"id": task_id,
"name": name,
"enabled": True,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
"schedule": schedule,
"action": action
}
# Calculate initial next_run_at
next_run = self._calculate_next_run(task_data)
if next_run:
task_data["next_run_at"] = next_run.isoformat()
# Save task
self.task_store.add_task(task_data)
# Format response
schedule_desc = self._format_schedule_description(schedule)
receiver_desc = task_data["action"]["receiver_name"] or task_data["action"]["receiver"]
if message:
content_desc = f"💬 固定消息: {message}"
else:
content_desc = f"🤖 AI任务: {ai_task}"
return (
f"✅ 定时任务创建成功\n\n"
f"📋 任务ID: {task_id}\n"
f"📝 名称: {name}\n"
f"⏰ 调度: {schedule_desc}\n"
f"👤 接收者: {receiver_desc}\n"
f"{content_desc}\n"
f"🕐 下次执行: {next_run.strftime('%Y-%m-%d %H:%M:%S') if next_run else '未知'}"
)
def _list_tasks(self, **kwargs) -> str:
"""List all tasks"""
tasks = self.task_store.list_tasks()
if not tasks:
return "📋 暂无定时任务"
lines = [f"📋 定时任务列表 (共 {len(tasks)} 个)\n"]
for task in tasks:
status = "" if task.get("enabled", True) else ""
schedule_desc = self._format_schedule_description(task.get("schedule", {}))
next_run = task.get("next_run_at")
next_run_str = datetime.fromisoformat(next_run).strftime('%m-%d %H:%M') if next_run else "未知"
lines.append(
f"{status} [{task['id']}] {task['name']}\n"
f"{schedule_desc} | 下次: {next_run_str}"
)
return "\n".join(lines)
def _get_task(self, **kwargs) -> str:
"""Get task details"""
task_id = kwargs.get("task_id")
if not task_id:
return "错误: 缺少任务ID (task_id)"
task = self.task_store.get_task(task_id)
if not task:
return f"错误: 任务 '{task_id}' 不存在"
status = "启用" if task.get("enabled", True) else "禁用"
schedule_desc = self._format_schedule_description(task.get("schedule", {}))
action = task.get("action", {})
next_run = task.get("next_run_at")
next_run_str = datetime.fromisoformat(next_run).strftime('%Y-%m-%d %H:%M:%S') if next_run else "未知"
last_run = task.get("last_run_at")
last_run_str = datetime.fromisoformat(last_run).strftime('%Y-%m-%d %H:%M:%S') if last_run else "从未执行"
return (
f"📋 任务详情\n\n"
f"ID: {task['id']}\n"
f"名称: {task['name']}\n"
f"状态: {status}\n"
f"调度: {schedule_desc}\n"
f"接收者: {action.get('receiver_name', action.get('receiver'))}\n"
f"消息: {action.get('content')}\n"
f"下次执行: {next_run_str}\n"
f"上次执行: {last_run_str}\n"
f"创建时间: {datetime.fromisoformat(task['created_at']).strftime('%Y-%m-%d %H:%M:%S')}"
)
def _delete_task(self, **kwargs) -> str:
"""Delete a task"""
task_id = kwargs.get("task_id")
if not task_id:
return "错误: 缺少任务ID (task_id)"
task = self.task_store.get_task(task_id)
if not task:
return f"错误: 任务 '{task_id}' 不存在"
self.task_store.delete_task(task_id)
return f"✅ 任务 '{task['name']}' ({task_id}) 已删除"
def _enable_task(self, **kwargs) -> str:
"""Enable a task"""
task_id = kwargs.get("task_id")
if not task_id:
return "错误: 缺少任务ID (task_id)"
task = self.task_store.get_task(task_id)
if not task:
return f"错误: 任务 '{task_id}' 不存在"
self.task_store.enable_task(task_id, True)
return f"✅ 任务 '{task['name']}' ({task_id}) 已启用"
def _disable_task(self, **kwargs) -> str:
"""Disable a task"""
task_id = kwargs.get("task_id")
if not task_id:
return "错误: 缺少任务ID (task_id)"
task = self.task_store.get_task(task_id)
if not task:
return f"错误: 任务 '{task_id}' 不存在"
self.task_store.enable_task(task_id, False)
return f"✅ 任务 '{task['name']}' ({task_id}) 已禁用"
def _parse_schedule(self, schedule_type: str, schedule_value: str) -> Optional[dict]:
"""Parse and validate schedule configuration"""
try:
if schedule_type == "cron":
# Validate cron expression
croniter(schedule_value)
return {"type": "cron", "expression": schedule_value}
elif schedule_type == "interval":
# Parse interval in seconds
seconds = int(schedule_value)
if seconds <= 0:
return None
return {"type": "interval", "seconds": seconds}
elif schedule_type == "once":
# Parse datetime - support both relative and absolute time
# Check if it's relative time (e.g., "+5s", "+10m", "+1h", "+1d")
if schedule_value.startswith("+"):
import re
match = re.match(r'\+(\d+)([smhd])', schedule_value)
if match:
amount = int(match.group(1))
unit = match.group(2)
from datetime import timedelta
now = datetime.now()
if unit == 's': # seconds
target_time = now + timedelta(seconds=amount)
elif unit == 'm': # minutes
target_time = now + timedelta(minutes=amount)
elif unit == 'h': # hours
target_time = now + timedelta(hours=amount)
elif unit == 'd': # days
target_time = now + timedelta(days=amount)
else:
return None
return {"type": "once", "run_at": target_time.isoformat()}
else:
logger.error(f"[SchedulerTool] Invalid relative time format: {schedule_value}")
return None
else:
# Absolute time in ISO format
datetime.fromisoformat(schedule_value)
return {"type": "once", "run_at": schedule_value}
except Exception as e:
logger.error(f"[SchedulerTool] Invalid schedule: {e}")
return None
return None
def _calculate_next_run(self, task: dict) -> Optional[datetime]:
"""Calculate next run time for a task"""
schedule = task.get("schedule", {})
schedule_type = schedule.get("type")
now = datetime.now()
if schedule_type == "cron":
expression = schedule.get("expression")
cron = croniter(expression, now)
return cron.get_next(datetime)
elif schedule_type == "interval":
seconds = schedule.get("seconds", 0)
from datetime import timedelta
return now + timedelta(seconds=seconds)
elif schedule_type == "once":
run_at_str = schedule.get("run_at")
return datetime.fromisoformat(run_at_str)
return None
def _format_schedule_description(self, schedule: dict) -> str:
"""Format schedule as human-readable description"""
schedule_type = schedule.get("type")
if schedule_type == "cron":
expr = schedule.get("expression", "")
# Try to provide friendly description
if expr == "0 9 * * *":
return "每天 9:00"
elif expr == "0 */1 * * *":
return "每小时"
elif expr == "*/30 * * * *":
return "每30分钟"
else:
return f"Cron: {expr}"
elif schedule_type == "interval":
seconds = schedule.get("seconds", 0)
if seconds >= 86400:
days = seconds // 86400
return f"{days}"
elif seconds >= 3600:
hours = seconds // 3600
return f"{hours} 小时"
elif seconds >= 60:
minutes = seconds // 60
return f"{minutes} 分钟"
else:
return f"{seconds}"
elif schedule_type == "once":
run_at = schedule.get("run_at", "")
try:
dt = datetime.fromisoformat(run_at)
return f"一次性 ({dt.strftime('%Y-%m-%d %H:%M')})"
except Exception:
return "一次性"
return "未知"
def _get_receiver_name(self, context: Context) -> str:
"""Get receiver name from context"""
try:
msg = context.get("msg")
if msg:
if context.get("isgroup"):
return msg.other_user_nickname or "群聊"
else:
return msg.from_user_nickname or "用户"
except Exception:
pass
return "未知"

View File

@@ -0,0 +1,201 @@
"""
Task storage management for scheduler
"""
import json
import os
import threading
from datetime import datetime
from typing import Dict, List, Optional
from pathlib import Path
from common.utils import expand_path
class TaskStore:
"""
Manages persistent storage of scheduled tasks
"""
def __init__(self, store_path: str = None):
"""
Initialize task store
Args:
store_path: Path to tasks.json file. Defaults to ~/cow/scheduler/tasks.json
"""
if store_path is None:
# Default to ~/cow/scheduler/tasks.json
home = expand_path("~")
store_path = os.path.join(home, "cow", "scheduler", "tasks.json")
self.store_path = store_path
self.lock = threading.Lock()
self._ensure_store_dir()
def _ensure_store_dir(self):
"""Ensure the storage directory exists"""
store_dir = os.path.dirname(self.store_path)
os.makedirs(store_dir, exist_ok=True)
def load_tasks(self) -> Dict[str, dict]:
"""
Load all tasks from storage
Returns:
Dictionary of task_id -> task_data
"""
with self.lock:
if not os.path.exists(self.store_path):
return {}
try:
with open(self.store_path, 'r', encoding='utf-8') as f:
data = json.load(f)
return data.get("tasks", {})
except Exception as e:
print(f"Error loading tasks: {e}")
return {}
def save_tasks(self, tasks: Dict[str, dict]):
"""
Save all tasks to storage
Args:
tasks: Dictionary of task_id -> task_data
"""
with self.lock:
try:
# Create backup
if os.path.exists(self.store_path):
backup_path = f"{self.store_path}.bak"
try:
with open(self.store_path, 'r') as src:
with open(backup_path, 'w') as dst:
dst.write(src.read())
except Exception:
pass
# Save tasks
data = {
"version": 1,
"updated_at": datetime.now().isoformat(),
"tasks": tasks
}
with open(self.store_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"Error saving tasks: {e}")
raise
def add_task(self, task: dict) -> bool:
"""
Add a new task
Args:
task: Task data dictionary
Returns:
True if successful
"""
tasks = self.load_tasks()
task_id = task.get("id")
if not task_id:
raise ValueError("Task must have an 'id' field")
if task_id in tasks:
raise ValueError(f"Task with id '{task_id}' already exists")
tasks[task_id] = task
self.save_tasks(tasks)
return True
def update_task(self, task_id: str, updates: dict) -> bool:
"""
Update an existing task
Args:
task_id: Task ID
updates: Dictionary of fields to update
Returns:
True if successful
"""
tasks = self.load_tasks()
if task_id not in tasks:
raise ValueError(f"Task '{task_id}' not found")
# Update fields
tasks[task_id].update(updates)
tasks[task_id]["updated_at"] = datetime.now().isoformat()
self.save_tasks(tasks)
return True
def delete_task(self, task_id: str) -> bool:
"""
Delete a task
Args:
task_id: Task ID
Returns:
True if successful
"""
tasks = self.load_tasks()
if task_id not in tasks:
raise ValueError(f"Task '{task_id}' not found")
del tasks[task_id]
self.save_tasks(tasks)
return True
def get_task(self, task_id: str) -> Optional[dict]:
"""
Get a specific task
Args:
task_id: Task ID
Returns:
Task data or None if not found
"""
tasks = self.load_tasks()
return tasks.get(task_id)
def list_tasks(self, enabled_only: bool = False) -> List[dict]:
"""
List all tasks
Args:
enabled_only: If True, only return enabled tasks
Returns:
List of task dictionaries
"""
tasks = self.load_tasks()
task_list = list(tasks.values())
if enabled_only:
task_list = [t for t in task_list if t.get("enabled", True)]
# Sort by next_run_at
task_list.sort(key=lambda t: t.get("next_run_at", float('inf')))
return task_list
def enable_task(self, task_id: str, enabled: bool = True) -> bool:
"""
Enable or disable a task
Args:
task_id: Task ID
enabled: True to enable, False to disable
Returns:
True if successful
"""
return self.update_task(task_id, {"enabled": enabled})

View File

@@ -0,0 +1,3 @@
from .send import Send
__all__ = ['Send']

171
agent/tools/send/send.py Normal file
View File

@@ -0,0 +1,171 @@
"""
Send tool - Send files to the user
"""
import os
from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool, ToolResult
from common.utils import expand_path
class Send(BaseTool):
"""Tool for sending files to the user"""
name: str = "send"
description: str = "Send a LOCAL file (image, video, audio, document) to the user. Only for local file paths. Do NOT use this for URLs — URLs should be included directly in your text reply, the system will handle them automatically."
params: dict = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Local file path to send. Must be an absolute path or relative to workspace. Do NOT pass URLs here."
},
"message": {
"type": "string",
"description": "Optional message to accompany the file"
}
},
"required": ["path"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
# Supported file types
self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.svg', '.ico'}
self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}
self.audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac', '.aac', '.wma'}
self.document_extensions = {'.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.txt', '.md'}
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute file send operation
:param args: Contains file path and optional message
:return: File metadata for channel to send
"""
path = args.get("path", "").strip()
message = args.get("message", "")
if not path:
return ToolResult.fail("Error: path parameter is required")
# Resolve path
absolute_path = self._resolve_path(path)
# Check if file exists
if not os.path.exists(absolute_path):
return ToolResult.fail(f"Error: File not found: {path}")
# Check if readable
if not os.access(absolute_path, os.R_OK):
return ToolResult.fail(f"Error: File is not readable: {path}")
# Get file info
file_ext = Path(absolute_path).suffix.lower()
file_size = os.path.getsize(absolute_path)
file_name = Path(absolute_path).name
# Determine file type
if file_ext in self.image_extensions:
file_type = "image"
mime_type = self._get_image_mime_type(file_ext)
elif file_ext in self.video_extensions:
file_type = "video"
mime_type = self._get_video_mime_type(file_ext)
elif file_ext in self.audio_extensions:
file_type = "audio"
mime_type = self._get_audio_mime_type(file_ext)
elif file_ext in self.document_extensions:
file_type = "document"
mime_type = self._get_document_mime_type(file_ext)
else:
file_type = "file"
mime_type = "application/octet-stream"
# Return file_to_send metadata
result = {
"type": "file_to_send",
"file_type": file_type,
"path": absolute_path,
"file_name": file_name,
"mime_type": mime_type,
"size": file_size,
"size_formatted": self._format_size(file_size),
"message": message or f"正在发送 {file_name}"
}
try:
from common.cloud_client import get_website_base_url, copy_send_file
# Do nothing when in local env
if get_website_base_url():
url = copy_send_file(absolute_path, self.cwd)
if url:
result["url"] = url
except Exception:
pass
return ToolResult.success(result)
def _resolve_path(self, path: str) -> str:
"""Resolve path to absolute path"""
path = expand_path(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))
def _get_image_mime_type(self, ext: str) -> str:
"""Get MIME type for image"""
mime_map = {
'.jpg': 'image/jpeg', '.jpeg': 'image/jpeg',
'.png': 'image/png', '.gif': 'image/gif',
'.webp': 'image/webp', '.bmp': 'image/bmp',
'.svg': 'image/svg+xml', '.ico': 'image/x-icon'
}
return mime_map.get(ext, 'image/jpeg')
def _get_video_mime_type(self, ext: str) -> str:
"""Get MIME type for video"""
mime_map = {
'.mp4': 'video/mp4', '.avi': 'video/x-msvideo',
'.mov': 'video/quicktime', '.mkv': 'video/x-matroska',
'.webm': 'video/webm', '.flv': 'video/x-flv'
}
return mime_map.get(ext, 'video/mp4')
def _get_audio_mime_type(self, ext: str) -> str:
"""Get MIME type for audio"""
mime_map = {
'.mp3': 'audio/mpeg', '.wav': 'audio/wav',
'.ogg': 'audio/ogg', '.m4a': 'audio/mp4',
'.flac': 'audio/flac', '.aac': 'audio/aac'
}
return mime_map.get(ext, 'audio/mpeg')
def _get_document_mime_type(self, ext: str) -> str:
"""Get MIME type for document"""
mime_map = {
'.pdf': 'application/pdf',
'.doc': 'application/msword',
'.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'.xls': 'application/vnd.ms-excel',
'.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
'.ppt': 'application/vnd.ms-powerpoint',
'.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
'.txt': 'text/plain',
'.md': 'text/markdown'
}
return mime_map.get(ext, 'application/octet-stream')
def _format_size(self, size_bytes: int) -> str:
"""Format file size in human-readable format"""
for unit in ['B', 'KB', 'MB', 'GB']:
if size_bytes < 1024.0:
return f"{size_bytes:.1f}{unit}"
size_bytes /= 1024.0
return f"{size_bytes:.1f}TB"

248
agent/tools/tool_manager.py Normal file
View File

@@ -0,0 +1,248 @@
import importlib
import importlib.util
from pathlib import Path
from typing import Dict, Any, Type
from agent.tools.base_tool import BaseTool
from common.log import logger
from config import conf
class ToolManager:
"""
Tool manager for managing tools.
"""
_instance = None
def __new__(cls):
"""Singleton pattern to ensure only one instance of ToolManager exists."""
if cls._instance is None:
cls._instance = super(ToolManager, cls).__new__(cls)
cls._instance.tool_classes = {} # Store tool classes instead of instances
cls._instance._initialized = False
return cls._instance
def __init__(self):
# Initialize only once
if not hasattr(self, 'tool_classes'):
self.tool_classes = {} # Dictionary to store tool classes
def load_tools(self, tools_dir: str = "", config_dict=None):
"""
Load tools from both directory and configuration.
:param tools_dir: Directory to scan for tool modules
"""
if tools_dir:
self._load_tools_from_directory(tools_dir)
self._configure_tools_from_config()
else:
self._load_tools_from_init()
self._configure_tools_from_config(config_dict)
def _load_tools_from_init(self) -> bool:
"""
Load tool classes from tools.__init__.__all__
:return: True if tools were loaded, False otherwise
"""
try:
# Try to import the tools package
tools_package = importlib.import_module("agent.tools")
# Check if __all__ is defined
if hasattr(tools_package, "__all__"):
tool_classes = tools_package.__all__
# Import each tool class directly from the tools package
for class_name in tool_classes:
try:
# Skip base classes
if class_name in ["BaseTool", "ToolManager"]:
continue
# Get the class directly from the tools package
if hasattr(tools_package, class_name):
cls = getattr(tools_package, class_name)
if (
isinstance(cls, type)
and issubclass(cls, BaseTool)
and cls != BaseTool
):
try:
# Skip memory tools (they need special initialization with memory_manager)
if class_name in ["MemorySearchTool", "MemoryGetTool"]:
logger.debug(f"Skipped tool {class_name} (requires memory_manager)")
continue
# Create a temporary instance to get the name
temp_instance = cls()
tool_name = temp_instance.name
# Store the class, not the instance
self.tool_classes[tool_name] = cls
logger.debug(f"Loaded tool: {tool_name} from class {class_name}")
except ImportError as e:
# Handle missing dependencies with helpful messages
error_msg = str(e)
if "playwright" in error_msg:
logger.warning(
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
f" To enable browser tool, run:\n"
f" pip install playwright\n"
f" playwright install chromium"
)
elif "markdownify" in error_msg:
logger.warning(
f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
f" Install with: pip install markdownify"
)
else:
logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
except Exception as e:
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
except Exception as e:
logger.error(f"Error importing class {class_name}: {e}")
return len(self.tool_classes) > 0
return False
except ImportError:
logger.warning("Could not import agent.tools package")
return False
except Exception as e:
logger.error(f"Error loading tools from __init__.__all__: {e}")
return False
def _load_tools_from_directory(self, tools_dir: str):
"""Dynamically load tool classes from directory"""
tools_path = Path(tools_dir)
# Traverse all .py files
for py_file in tools_path.rglob("*.py"):
# Skip initialization files and base tool files
if py_file.name in ["__init__.py", "base_tool.py", "tool_manager.py"]:
continue
# Get module name
module_name = py_file.stem
try:
# Load module directly from file
spec = importlib.util.spec_from_file_location(module_name, py_file)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Find tool classes in the module
for attr_name in dir(module):
cls = getattr(module, attr_name)
if (
isinstance(cls, type)
and issubclass(cls, BaseTool)
and cls != BaseTool
):
try:
# Skip memory tools (they need special initialization with memory_manager)
if attr_name in ["MemorySearchTool", "MemoryGetTool"]:
logger.debug(f"Skipped tool {attr_name} (requires memory_manager)")
continue
# Create a temporary instance to get the name
temp_instance = cls()
tool_name = temp_instance.name
# Store the class, not the instance
self.tool_classes[tool_name] = cls
except ImportError as e:
# Handle missing dependencies with helpful messages
error_msg = str(e)
if "playwright" in error_msg:
logger.warning(
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
f" To enable browser tool, run:\n"
f" pip install playwright\n"
f" playwright install chromium"
)
elif "markdownify" in error_msg:
logger.warning(
f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
f" Install with: pip install markdownify"
)
else:
logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
except Exception as e:
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
except Exception as e:
print(f"Error importing module {py_file}: {e}")
def _configure_tools_from_config(self, config_dict=None):
"""Configure tool classes based on configuration file"""
try:
# Get tools configuration
tools_config = config_dict or conf().get("tools", {})
# Record tools that are configured but not loaded
missing_tools = []
# Store configurations for later use when instantiating
self.tool_configs = tools_config
# Check which configured tools are missing
for tool_name in tools_config:
if tool_name not in self.tool_classes:
missing_tools.append(tool_name)
# If there are missing tools, record warnings
if missing_tools:
for tool_name in missing_tools:
if tool_name == "browser":
logger.warning(
f"[ToolManager] Browser tool is configured but not loaded.\n"
f" To enable browser tool, run:\n"
f" pip install playwright\n"
f" playwright install chromium"
)
elif tool_name == "google_search":
logger.warning(
f"[ToolManager] Google Search tool is configured but may need API key.\n"
f" Get API key from: https://serper.dev\n"
f" Configure in config.json: tools.google_search.api_key"
)
else:
logger.warning(f"[ToolManager] Tool '{tool_name}' is configured but could not be loaded.")
except Exception as e:
logger.error(f"Error configuring tools from config: {e}")
def create_tool(self, name: str) -> BaseTool:
"""
Get a new instance of a tool by name.
:param name: The name of the tool to get.
:return: A new instance of the tool or None if not found.
"""
tool_class = self.tool_classes.get(name)
if tool_class:
# Create a new instance
tool_instance = tool_class()
# Apply configuration if available
if hasattr(self, 'tool_configs') and name in self.tool_configs:
tool_instance.config = self.tool_configs[name]
return tool_instance
return None
def list_tools(self) -> dict:
"""
Get information about all loaded tools.
:return: A dictionary with tool information.
"""
result = {}
for name, tool_class in self.tool_classes.items():
# Create a temporary instance to get schema
temp_instance = tool_class()
result[name] = {
"description": temp_instance.description,
"parameters": temp_instance.get_json_schema()
}
return result

View File

@@ -0,0 +1,40 @@
from .truncate import (
truncate_head,
truncate_tail,
truncate_line,
format_size,
TruncationResult,
DEFAULT_MAX_LINES,
DEFAULT_MAX_BYTES,
GREP_MAX_LINE_LENGTH
)
from .diff import (
strip_bom,
detect_line_ending,
normalize_to_lf,
restore_line_endings,
normalize_for_fuzzy_match,
fuzzy_find_text,
generate_diff_string,
FuzzyMatchResult
)
__all__ = [
'truncate_head',
'truncate_tail',
'truncate_line',
'format_size',
'TruncationResult',
'DEFAULT_MAX_LINES',
'DEFAULT_MAX_BYTES',
'GREP_MAX_LINE_LENGTH',
'strip_bom',
'detect_line_ending',
'normalize_to_lf',
'restore_line_endings',
'normalize_for_fuzzy_match',
'fuzzy_find_text',
'generate_diff_string',
'FuzzyMatchResult'
]

167
agent/tools/utils/diff.py Normal file
View File

@@ -0,0 +1,167 @@
"""
Diff tools for file editing
Provides fuzzy matching and diff generation functionality
"""
import difflib
import re
from typing import Optional, Tuple
def strip_bom(text: str) -> Tuple[str, str]:
"""
Remove BOM (Byte Order Mark)
:param text: Original text
:return: (BOM, text after removing BOM)
"""
if text.startswith('\ufeff'):
return '\ufeff', text[1:]
return '', text
def detect_line_ending(text: str) -> str:
"""
Detect line ending type
:param text: Text content
:return: Line ending type ('\r\n' or '\n')
"""
if '\r\n' in text:
return '\r\n'
return '\n'
def normalize_to_lf(text: str) -> str:
"""
Normalize all line endings to LF (\n)
:param text: Original text
:return: Normalized text
"""
return text.replace('\r\n', '\n').replace('\r', '\n')
def restore_line_endings(text: str, original_ending: str) -> str:
"""
Restore original line endings
:param text: LF normalized text
:param original_ending: Original line ending
:return: Text with restored line endings
"""
if original_ending == '\r\n':
return text.replace('\n', '\r\n')
return text
def normalize_for_fuzzy_match(text: str) -> str:
"""
Normalize text for fuzzy matching
Remove excess whitespace but preserve basic structure
:param text: Original text
:return: Normalized text
"""
# Compress multiple spaces to one
text = re.sub(r'[ \t]+', ' ', text)
# Remove trailing spaces
text = re.sub(r' +\n', '\n', text)
# Remove leading spaces (but preserve indentation structure, only remove excess)
lines = text.split('\n')
normalized_lines = []
for line in lines:
# Preserve indentation but normalize to multiples of single spaces
stripped = line.lstrip()
if stripped:
indent_count = len(line) - len(stripped)
# Normalize indentation (convert tabs to spaces)
normalized_indent = ' ' * indent_count
normalized_lines.append(normalized_indent + stripped)
else:
normalized_lines.append('')
return '\n'.join(normalized_lines)
class FuzzyMatchResult:
"""Fuzzy match result"""
def __init__(self, found: bool, index: int = -1, match_length: int = 0, content_for_replacement: str = ""):
self.found = found
self.index = index
self.match_length = match_length
self.content_for_replacement = content_for_replacement
def fuzzy_find_text(content: str, old_text: str) -> FuzzyMatchResult:
"""
Find text in content, try exact match first, then fuzzy match
:param content: Content to search in
:param old_text: Text to find
:return: Match result
"""
# First try exact match
index = content.find(old_text)
if index != -1:
return FuzzyMatchResult(
found=True,
index=index,
match_length=len(old_text),
content_for_replacement=content
)
# Try fuzzy match
fuzzy_content = normalize_for_fuzzy_match(content)
fuzzy_old_text = normalize_for_fuzzy_match(old_text)
index = fuzzy_content.find(fuzzy_old_text)
if index != -1:
# Fuzzy match successful, use normalized content for replacement
return FuzzyMatchResult(
found=True,
index=index,
match_length=len(fuzzy_old_text),
content_for_replacement=fuzzy_content
)
# Not found
return FuzzyMatchResult(found=False)
def generate_diff_string(old_content: str, new_content: str) -> dict:
"""
Generate unified diff string
:param old_content: Old content
:param new_content: New content
:return: Dictionary containing diff and first changed line number
"""
old_lines = old_content.split('\n')
new_lines = new_content.split('\n')
# Generate unified diff
diff_lines = list(difflib.unified_diff(
old_lines,
new_lines,
lineterm='',
fromfile='original',
tofile='modified'
))
# Find first changed line number
first_changed_line = None
for line in diff_lines:
if line.startswith('@@'):
# Parse @@ -1,3 +1,3 @@ format
match = re.search(r'@@ -\d+,?\d* \+(\d+)', line)
if match:
first_changed_line = int(match.group(1))
break
diff_string = '\n'.join(diff_lines)
return {
'diff': diff_string,
'first_changed_line': first_changed_line
}

View File

@@ -0,0 +1,292 @@
"""
Shared truncation utilities for tool outputs.
Truncation is based on two independent limits - whichever is hit first wins:
- Line limit (default: 2000 lines)
- Byte limit (default: 50KB)
Never returns partial lines (except bash tail truncation edge case).
"""
from typing import Dict, Any, Optional, Literal, Tuple
DEFAULT_MAX_LINES = 2000
DEFAULT_MAX_BYTES = 50 * 1024 # 50KB
GREP_MAX_LINE_LENGTH = 500 # Max chars per grep match line
class TruncationResult:
"""Truncation result"""
def __init__(
self,
content: str,
truncated: bool,
truncated_by: Optional[Literal["lines", "bytes"]],
total_lines: int,
total_bytes: int,
output_lines: int,
output_bytes: int,
last_line_partial: bool = False,
first_line_exceeds_limit: bool = False,
max_lines: int = DEFAULT_MAX_LINES,
max_bytes: int = DEFAULT_MAX_BYTES
):
self.content = content
self.truncated = truncated
self.truncated_by = truncated_by
self.total_lines = total_lines
self.total_bytes = total_bytes
self.output_lines = output_lines
self.output_bytes = output_bytes
self.last_line_partial = last_line_partial
self.first_line_exceeds_limit = first_line_exceeds_limit
self.max_lines = max_lines
self.max_bytes = max_bytes
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary"""
return {
"content": self.content,
"truncated": self.truncated,
"truncated_by": self.truncated_by,
"total_lines": self.total_lines,
"total_bytes": self.total_bytes,
"output_lines": self.output_lines,
"output_bytes": self.output_bytes,
"last_line_partial": self.last_line_partial,
"first_line_exceeds_limit": self.first_line_exceeds_limit,
"max_lines": self.max_lines,
"max_bytes": self.max_bytes
}
def format_size(bytes_count: int) -> str:
"""Format bytes as human-readable size"""
if bytes_count < 1024:
return f"{bytes_count}B"
elif bytes_count < 1024 * 1024:
return f"{bytes_count / 1024:.1f}KB"
else:
return f"{bytes_count / (1024 * 1024):.1f}MB"
def truncate_head(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
"""
Truncate content from the head (keep first N lines/bytes).
Suitable for file reads where you want to see the beginning.
Never returns partial lines. If first line exceeds byte limit,
returns empty content with first_line_exceeds_limit=True.
:param content: Content to truncate
:param max_lines: Maximum number of lines (default: 2000)
:param max_bytes: Maximum number of bytes (default: 50KB)
:return: Truncation result
"""
if max_lines is None:
max_lines = DEFAULT_MAX_LINES
if max_bytes is None:
max_bytes = DEFAULT_MAX_BYTES
total_bytes = len(content.encode('utf-8'))
lines = content.split('\n')
total_lines = len(lines)
# Check if no truncation is needed
if total_lines <= max_lines and total_bytes <= max_bytes:
return TruncationResult(
content=content,
truncated=False,
truncated_by=None,
total_lines=total_lines,
total_bytes=total_bytes,
output_lines=total_lines,
output_bytes=total_bytes,
last_line_partial=False,
first_line_exceeds_limit=False,
max_lines=max_lines,
max_bytes=max_bytes
)
# Check if first line alone exceeds byte limit
first_line_bytes = len(lines[0].encode('utf-8'))
if first_line_bytes > max_bytes:
return TruncationResult(
content="",
truncated=True,
truncated_by="bytes",
total_lines=total_lines,
total_bytes=total_bytes,
output_lines=0,
output_bytes=0,
last_line_partial=False,
first_line_exceeds_limit=True,
max_lines=max_lines,
max_bytes=max_bytes
)
# Collect complete lines that fit
output_lines_arr = []
output_bytes_count = 0
truncated_by = "lines"
for i, line in enumerate(lines):
if i >= max_lines:
break
# Calculate line bytes (add 1 for newline if not first line)
line_bytes = len(line.encode('utf-8')) + (1 if i > 0 else 0)
if output_bytes_count + line_bytes > max_bytes:
truncated_by = "bytes"
break
output_lines_arr.append(line)
output_bytes_count += line_bytes
# If exited due to line limit
if len(output_lines_arr) >= max_lines and output_bytes_count <= max_bytes:
truncated_by = "lines"
output_content = '\n'.join(output_lines_arr)
final_output_bytes = len(output_content.encode('utf-8'))
return TruncationResult(
content=output_content,
truncated=True,
truncated_by=truncated_by,
total_lines=total_lines,
total_bytes=total_bytes,
output_lines=len(output_lines_arr),
output_bytes=final_output_bytes,
last_line_partial=False,
first_line_exceeds_limit=False,
max_lines=max_lines,
max_bytes=max_bytes
)
def truncate_tail(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
"""
Truncate content from tail (keep last N lines/bytes).
Suitable for bash output where you want to see the ending content (errors, final results).
If the last line of original content exceeds byte limit, may return partial first line.
:param content: Content to truncate
:param max_lines: Maximum lines (default: 2000)
:param max_bytes: Maximum bytes (default: 50KB)
:return: Truncation result
"""
if max_lines is None:
max_lines = DEFAULT_MAX_LINES
if max_bytes is None:
max_bytes = DEFAULT_MAX_BYTES
total_bytes = len(content.encode('utf-8'))
lines = content.split('\n')
total_lines = len(lines)
# Check if no truncation is needed
if total_lines <= max_lines and total_bytes <= max_bytes:
return TruncationResult(
content=content,
truncated=False,
truncated_by=None,
total_lines=total_lines,
total_bytes=total_bytes,
output_lines=total_lines,
output_bytes=total_bytes,
last_line_partial=False,
first_line_exceeds_limit=False,
max_lines=max_lines,
max_bytes=max_bytes
)
# Work backwards from the end
output_lines_arr = []
output_bytes_count = 0
truncated_by = "lines"
last_line_partial = False
for i in range(len(lines) - 1, -1, -1):
if len(output_lines_arr) >= max_lines:
break
line = lines[i]
# Calculate line bytes (add newline if not the first added line)
line_bytes = len(line.encode('utf-8')) + (1 if len(output_lines_arr) > 0 else 0)
if output_bytes_count + line_bytes > max_bytes:
truncated_by = "bytes"
# Edge case: if we haven't added any lines yet and this line exceeds maxBytes,
# take the end portion of this line
if len(output_lines_arr) == 0:
truncated_line = _truncate_string_to_bytes_from_end(line, max_bytes)
output_lines_arr.insert(0, truncated_line)
output_bytes_count = len(truncated_line.encode('utf-8'))
last_line_partial = True
break
output_lines_arr.insert(0, line)
output_bytes_count += line_bytes
# If exited due to line limit
if len(output_lines_arr) >= max_lines and output_bytes_count <= max_bytes:
truncated_by = "lines"
output_content = '\n'.join(output_lines_arr)
final_output_bytes = len(output_content.encode('utf-8'))
return TruncationResult(
content=output_content,
truncated=True,
truncated_by=truncated_by,
total_lines=total_lines,
total_bytes=total_bytes,
output_lines=len(output_lines_arr),
output_bytes=final_output_bytes,
last_line_partial=last_line_partial,
first_line_exceeds_limit=False,
max_lines=max_lines,
max_bytes=max_bytes
)
def _truncate_string_to_bytes_from_end(text: str, max_bytes: int) -> str:
"""
Truncate string to fit byte limit (from end).
Properly handles multi-byte UTF-8 characters.
:param text: String to truncate
:param max_bytes: Maximum bytes
:return: Truncated string
"""
encoded = text.encode('utf-8')
if len(encoded) <= max_bytes:
return text
# Start from end, skip back maxBytes
start = len(encoded) - max_bytes
# Find valid UTF-8 boundary (character start)
while start < len(encoded) and (encoded[start] & 0xC0) == 0x80:
start += 1
return encoded[start:].decode('utf-8', errors='ignore')
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> Tuple[str, bool]:
"""
Truncate single line to max characters, add [truncated] suffix.
Used for grep match lines.
:param line: Line to truncate
:param max_chars: Maximum characters
:return: (truncated text, whether truncated)
"""
if len(line) <= max_chars:
return line, False
return f"{line[:max_chars]}... [truncated]", True

View File

@@ -0,0 +1 @@
from agent.tools.vision.vision import Vision

View File

@@ -0,0 +1,280 @@
"""
Vision tool - Analyze images using OpenAI-compatible Vision API.
Supports local files (auto base64-encoded) and HTTP URLs.
Providers: OpenAI (preferred) > LinkAI (fallback).
"""
import base64
import os
import subprocess
import tempfile
from typing import Any, Dict, Optional, Tuple
import requests
from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger
from config import conf
DEFAULT_MODEL = "gpt-4.1-mini"
DEFAULT_TIMEOUT = 60
MAX_TOKENS = 1000
COMPRESS_THRESHOLD = 1_048_576 # 1 MB
SUPPORTED_EXTENSIONS = {
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"png": "image/png",
"gif": "image/gif",
"webp": "image/webp",
}
class Vision(BaseTool):
"""Analyze images using OpenAI-compatible Vision API"""
name: str = "vision"
description: str = (
"Analyze a local image or image URL (jpg/jpeg/png) using Vision API. "
"Can describe content, extract text, identify objects, colors, etc. "
"Requires OPENAI_API_KEY or LINKAI_API_KEY."
)
params: dict = {
"type": "object",
"properties": {
"image": {
"type": "string",
"description": "Local file path or HTTP(S) URL of the image to analyze",
},
"question": {
"type": "string",
"description": "Question to ask about the image",
},
"model": {
"type": "string",
"description": (
f"Vision model to use (default: {DEFAULT_MODEL}). "
"Options: gpt-4.1-mini, gpt-4.1, gpt-4o-mini, gpt-4o"
),
},
},
"required": ["image", "question"],
}
def __init__(self, config: dict = None):
self.config = config or {}
@staticmethod
def is_available() -> bool:
return bool(
conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
or conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
)
def execute(self, args: Dict[str, Any]) -> ToolResult:
image = args.get("image", "").strip()
question = args.get("question", "").strip()
model = args.get("model", DEFAULT_MODEL).strip() or DEFAULT_MODEL
if not image:
return ToolResult.fail("Error: 'image' parameter is required")
if not question:
return ToolResult.fail("Error: 'question' parameter is required")
api_key, api_base, extra_headers = self._resolve_provider()
if not api_key:
return ToolResult.fail(
"Error: No API key configured for Vision.\n"
"Please configure one of the following using env_config tool:\n"
" 1. OPENAI_API_KEY (preferred): env_config(action=\"set\", key=\"OPENAI_API_KEY\", value=\"your-key\")\n"
" 2. LINKAI_API_KEY (fallback): env_config(action=\"set\", key=\"LINKAI_API_KEY\", value=\"your-key\")\n\n"
"Get your key at: https://platform.openai.com/api-keys or https://link-ai.tech"
)
try:
image_content = self._build_image_content(image)
except Exception as e:
return ToolResult.fail(f"Error: {e}")
try:
return self._call_api(api_key, api_base, model, question, image_content, extra_headers)
except requests.Timeout:
return ToolResult.fail(f"Error: Vision API request timed out after {DEFAULT_TIMEOUT}s")
except requests.ConnectionError:
return ToolResult.fail("Error: Failed to connect to Vision API")
except Exception as e:
logger.error(f"[Vision] Unexpected error: {e}", exc_info=True)
return ToolResult.fail(f"Error: Vision API call failed - {e}")
def _resolve_provider(self) -> Tuple[Optional[str], str, dict]:
"""Resolve API key, base URL and extra headers. Priority: conf() > env vars."""
api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
if api_key:
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
or "https://api.openai.com/v1"
return api_key, self._ensure_v1(api_base), {}
api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
if api_key:
api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \
or "https://api.link-ai.tech"
logger.debug("[Vision] Using LinkAI API (OPENAI_API_KEY not set)")
from common.utils import get_cloud_headers
extra = get_cloud_headers(api_key)
extra.pop("Authorization", None)
extra.pop("Content-Type", None)
return api_key, self._ensure_v1(api_base), extra
return None, "", {}
@staticmethod
def _ensure_v1(api_base: str) -> str:
"""Append /v1 if the base URL doesn't already end with a versioned path."""
if not api_base:
return api_base
# Already has /v1 or similar version suffix
if api_base.rstrip("/").split("/")[-1].startswith("v"):
return api_base
return api_base.rstrip("/") + "/v1"
def _build_image_content(self, image: str) -> dict:
"""Build the image_url content block for the API request."""
if image.startswith(("http://", "https://")):
return {"type": "image_url", "image_url": {"url": image}}
if not os.path.isfile(image):
raise FileNotFoundError(f"Image file not found: {image}")
ext = image.rsplit(".", 1)[-1].lower() if "." in image else ""
mime_type = SUPPORTED_EXTENSIONS.get(ext)
if not mime_type:
raise ValueError(
f"Unsupported image format '.{ext}'. "
f"Supported: {', '.join(SUPPORTED_EXTENSIONS.keys())}"
)
file_path = self._maybe_compress(image)
try:
with open(file_path, "rb") as f:
b64 = base64.b64encode(f.read()).decode("ascii")
finally:
if file_path != image and os.path.exists(file_path):
os.remove(file_path)
data_url = f"data:{mime_type};base64,{b64}"
return {"type": "image_url", "image_url": {"url": data_url}}
@staticmethod
def _maybe_compress(path: str) -> str:
"""Compress image to under COMPRESS_THRESHOLD with max long-edge 1536px."""
file_size = os.path.getsize(path)
if file_size <= COMPRESS_THRESHOLD:
return path
tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
tmp.close()
def _try_sips(max_dim: str, quality: str) -> bool:
try:
subprocess.run(
["sips", "-Z", max_dim, "-s", "formatOptions", quality,
path, "--out", tmp.name],
capture_output=True, check=True,
)
return True
except (FileNotFoundError, subprocess.CalledProcessError):
return False
def _try_convert(max_dim: str, quality: str) -> bool:
try:
subprocess.run(
["convert", path, "-resize", f"{max_dim}x{max_dim}>",
"-quality", quality, tmp.name],
capture_output=True, check=True,
)
return True
except (FileNotFoundError, subprocess.CalledProcessError):
return False
attempts = [
("1536", "85"),
("1536", "70"),
("1536", "50"),
]
for max_dim, quality in attempts:
ok = _try_sips(max_dim, quality) or _try_convert(max_dim, quality)
if not ok:
continue
new_size = os.path.getsize(tmp.name)
logger.debug(f"[Vision] Compressed image "
f"({file_size // 1024}KB -> {new_size // 1024}KB, "
f"max_dim={max_dim}, q={quality})")
if new_size <= COMPRESS_THRESHOLD:
return tmp.name
if os.path.exists(tmp.name) and os.path.getsize(tmp.name) > 0:
return tmp.name
os.remove(tmp.name)
return path
def _call_api(self, api_key: str, api_base: str, model: str,
question: str, image_content: dict, extra_headers: dict = None) -> ToolResult:
payload = {
"model": model,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": question},
image_content,
],
}
],
"max_tokens": MAX_TOKENS,
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
**(extra_headers or {}),
}
resp = requests.post(
f"{api_base}/chat/completions",
headers=headers,
json=payload,
timeout=DEFAULT_TIMEOUT,
)
if resp.status_code == 401:
return ToolResult.fail("Error: Invalid API key. Please check your configuration.")
if resp.status_code == 429:
return ToolResult.fail("Error: API rate limit reached. Please try again later.")
if resp.status_code != 200:
return ToolResult.fail(f"Error: Vision API returned HTTP {resp.status_code}: {resp.text[:200]}")
data = resp.json()
if "error" in data:
msg = data["error"].get("message", "Unknown API error")
return ToolResult.fail(f"Error: Vision API error - {msg}")
content = ""
choices = data.get("choices", [])
if choices:
content = choices[0].get("message", {}).get("content", "")
usage = data.get("usage", {})
result = {
"model": model,
"content": content,
"usage": {
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
},
}
return ToolResult.success(result)

View File

View File

@@ -0,0 +1,444 @@
"""
Web Fetch tool - Fetch and extract readable content from web pages and remote files.
Supports:
- HTML web pages: extracts readable text content
- Document files (PDF, Word, TXT, Markdown, etc.): downloads to workspace/tmp and parses content
"""
import os
import re
import uuid
from typing import Dict, Any, Optional, Set
from urllib.parse import urlparse, unquote
import requests
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size
from common.log import logger
DEFAULT_TIMEOUT = 30
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
DEFAULT_HEADERS = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36",
"Accept": "*/*",
}
# Supported document file extensions
PDF_SUFFIXES: Set[str] = {".pdf"}
WORD_SUFFIXES: Set[str] = {".docx"}
TEXT_SUFFIXES: Set[str] = {".txt", ".md", ".markdown", ".rst", ".csv", ".tsv", ".log"}
SPREADSHEET_SUFFIXES: Set[str] = {".xls", ".xlsx"}
PPT_SUFFIXES: Set[str] = {".ppt", ".pptx"}
ALL_DOC_SUFFIXES = PDF_SUFFIXES | WORD_SUFFIXES | TEXT_SUFFIXES | SPREADSHEET_SUFFIXES | PPT_SUFFIXES
_CHARSET_RE = re.compile(r'charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
_META_CHARSET_RE = re.compile(rb'<meta[^>]+charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
_META_HTTP_EQUIV_RE = re.compile(
rb'<meta[^>]+http-equiv\s*=\s*["\']?Content-Type["\']?[^>]+content\s*=\s*["\'][^"\']*charset=([\w\-]+)',
re.IGNORECASE,
)
def _extract_charset_from_content_type(content_type: str) -> Optional[str]:
"""Extract charset from Content-Type header value."""
m = _CHARSET_RE.search(content_type)
return m.group(1) if m else None
def _extract_charset_from_html_meta(raw_bytes: bytes) -> Optional[str]:
"""Extract charset from HTML <meta> tags in the first few KB of raw bytes."""
m = _META_CHARSET_RE.search(raw_bytes)
if m:
return m.group(1).decode("ascii", errors="ignore")
m = _META_HTTP_EQUIV_RE.search(raw_bytes)
if m:
return m.group(1).decode("ascii", errors="ignore")
return None
def _get_url_suffix(url: str) -> str:
"""Extract file extension from URL path, ignoring query params."""
path = urlparse(url).path
return os.path.splitext(path)[-1].lower()
def _is_document_url(url: str) -> bool:
"""Check if URL points to a downloadable document file."""
suffix = _get_url_suffix(url)
return suffix in ALL_DOC_SUFFIXES
class WebFetch(BaseTool):
"""Tool for fetching web pages and remote document files"""
name: str = "web_fetch"
description: str = (
"Fetch content from a http/https URL. For web pages, extracts readable text. "
"For document files (PDF, Word, TXT, Markdown, Excel, PPT), downloads and parses the file content. "
"Supported file types: .pdf, .docx, .txt, .md, .csv, .xls, .xlsx, .ppt, .pptx"
)
params: dict = {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The HTTP/HTTPS URL to fetch (web page or document file link)"
}
},
"required": ["url"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
def execute(self, args: Dict[str, Any]) -> ToolResult:
url = args.get("url", "").strip()
if not url:
return ToolResult.fail("Error: 'url' parameter is required")
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return ToolResult.fail("Error: Invalid URL (must start with http:// or https://)")
if _is_document_url(url):
return self._fetch_document(url)
return self._fetch_webpage(url)
# ---- Web page fetching ----
def _fetch_webpage(self, url: str) -> ToolResult:
"""Fetch and extract readable text from an HTML web page."""
parsed = urlparse(url)
try:
response = requests.get(
url,
headers=DEFAULT_HEADERS,
timeout=DEFAULT_TIMEOUT,
allow_redirects=True,
)
response.raise_for_status()
except requests.Timeout:
return ToolResult.fail(f"Error: Request timed out after {DEFAULT_TIMEOUT}s")
except requests.ConnectionError:
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
except requests.HTTPError as e:
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
except Exception as e:
return ToolResult.fail(f"Error: Failed to fetch URL: {e}")
content_type = response.headers.get("Content-Type", "")
if self._is_binary_content_type(content_type) and not _is_document_url(url):
return self._handle_download_by_content_type(url, response, content_type)
response.encoding = self._detect_encoding(response)
html = response.text
title = self._extract_title(html)
text = self._extract_text(html)
return ToolResult.success(f"Title: {title}\n\nContent:\n{text}")
# ---- Document fetching ----
def _fetch_document(self, url: str) -> ToolResult:
"""Download a document file and extract its text content."""
suffix = _get_url_suffix(url)
parsed = urlparse(url)
filename = self._extract_filename(url)
tmp_dir = self._ensure_tmp_dir()
local_path = os.path.join(tmp_dir, filename)
logger.info(f"[WebFetch] Downloading document: {url} -> {local_path}")
try:
response = requests.get(
url,
headers=DEFAULT_HEADERS,
timeout=DEFAULT_TIMEOUT,
stream=True,
allow_redirects=True,
)
response.raise_for_status()
content_length = int(response.headers.get("Content-Length", 0))
if content_length > MAX_FILE_SIZE:
return ToolResult.fail(
f"Error: File too large ({format_size(content_length)} > {format_size(MAX_FILE_SIZE)})"
)
downloaded = 0
with open(local_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
downloaded += len(chunk)
if downloaded > MAX_FILE_SIZE:
f.close()
os.remove(local_path)
return ToolResult.fail(
f"Error: File too large (>{format_size(MAX_FILE_SIZE)}), download aborted"
)
f.write(chunk)
except requests.Timeout:
return ToolResult.fail(f"Error: Download timed out after {DEFAULT_TIMEOUT}s")
except requests.ConnectionError:
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
except requests.HTTPError as e:
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
except Exception as e:
self._cleanup_file(local_path)
return ToolResult.fail(f"Error: Failed to download file: {e}")
try:
text = self._parse_document(local_path, suffix)
except Exception as e:
self._cleanup_file(local_path)
return ToolResult.fail(f"Error: Failed to parse document: {e}")
if not text or not text.strip():
file_size = os.path.getsize(local_path)
return ToolResult.success(
f"File downloaded to: {local_path} ({format_size(file_size)})\n"
f"No text content could be extracted. The file may contain only images or be encrypted."
)
truncation = truncate_head(text)
result_text = truncation.content
file_size = os.path.getsize(local_path)
header = f"[Document: {filename} | Size: {format_size(file_size)} | Saved to: {local_path}]\n\n"
if truncation.truncated:
header += f"[Content truncated: showing {truncation.output_lines} of {truncation.total_lines} lines]\n\n"
return ToolResult.success(header + result_text)
def _parse_document(self, file_path: str, suffix: str) -> str:
"""Parse document file and return extracted text."""
if suffix in PDF_SUFFIXES:
return self._parse_pdf(file_path)
elif suffix in WORD_SUFFIXES:
return self._parse_word(file_path)
elif suffix in TEXT_SUFFIXES:
return self._parse_text(file_path)
elif suffix in SPREADSHEET_SUFFIXES:
return self._parse_spreadsheet(file_path)
elif suffix in PPT_SUFFIXES:
return self._parse_ppt(file_path)
else:
return self._parse_text(file_path)
def _parse_pdf(self, file_path: str) -> str:
"""Extract text from PDF using pypdf."""
try:
from pypdf import PdfReader
except ImportError:
raise ImportError("pypdf library is required for PDF parsing. Install with: pip install pypdf")
reader = PdfReader(file_path)
text_parts = []
for page_num, page in enumerate(reader.pages, 1):
page_text = page.extract_text()
if page_text and page_text.strip():
text_parts.append(f"--- Page {page_num}/{len(reader.pages)} ---\n{page_text}")
return "\n\n".join(text_parts)
def _parse_word(self, file_path: str) -> str:
"""Extract text from Word documents (.docx)."""
try:
from docx import Document
except ImportError:
raise ImportError(
"python-docx library is required for .docx parsing. Install with: pip install python-docx"
)
doc = Document(file_path)
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
return "\n\n".join(paragraphs)
def _parse_text(self, file_path: str) -> str:
"""Read plain text files (txt, md, csv, etc.)."""
encodings = ["utf-8", "utf-8-sig", "gbk", "gb2312", "latin-1"]
for enc in encodings:
try:
with open(file_path, "r", encoding=enc) as f:
return f.read()
except (UnicodeDecodeError, UnicodeError):
continue
raise ValueError(f"Unable to decode file with any supported encoding: {encodings}")
def _parse_spreadsheet(self, file_path: str) -> str:
"""Extract text from Excel files (.xls/.xlsx)."""
try:
import openpyxl
except ImportError:
raise ImportError(
"openpyxl library is required for .xlsx parsing. Install with: pip install openpyxl"
)
wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True)
result_parts = []
for sheet_name in wb.sheetnames:
ws = wb[sheet_name]
rows = []
for row in ws.iter_rows(values_only=True):
cells = [str(c) if c is not None else "" for c in row]
if any(cells):
rows.append(" | ".join(cells))
if rows:
result_parts.append(f"--- Sheet: {sheet_name} ---\n" + "\n".join(rows))
wb.close()
return "\n\n".join(result_parts)
def _parse_ppt(self, file_path: str) -> str:
"""Extract text from PowerPoint files (.ppt/.pptx)."""
try:
from pptx import Presentation
except ImportError:
raise ImportError(
"python-pptx library is required for .pptx parsing. Install with: pip install python-pptx"
)
prs = Presentation(file_path)
text_parts = []
for slide_num, slide in enumerate(prs.slides, 1):
slide_texts = []
for shape in slide.shapes:
if shape.has_text_frame:
for paragraph in shape.text_frame.paragraphs:
text = paragraph.text.strip()
if text:
slide_texts.append(text)
if slide_texts:
text_parts.append(f"--- Slide {slide_num}/{len(prs.slides)} ---\n" + "\n".join(slide_texts))
return "\n\n".join(text_parts)
# ---- Encoding detection ----
@staticmethod
def _detect_encoding(response: requests.Response) -> str:
"""Detect response encoding with priority: Content-Type header > HTML meta > chardet > utf-8."""
# 1. Check Content-Type header for explicit charset
content_type = response.headers.get("Content-Type", "")
charset = _extract_charset_from_content_type(content_type)
if charset:
return charset
# 2. Scan raw bytes for HTML meta charset declaration
raw = response.content[:4096]
charset = _extract_charset_from_html_meta(raw)
if charset:
return charset
# 3. Use apparent_encoding (chardet-based detection) if confident enough
apparent = response.apparent_encoding
if apparent:
apparent_lower = apparent.lower()
# Trust CJK / Windows encodings detected by chardet
trusted_prefixes = ("utf", "gb", "big5", "euc", "shift_jis", "iso-2022", "windows", "ascii")
if any(apparent_lower.startswith(p) for p in trusted_prefixes):
return apparent
# 4. Fallback
return "utf-8"
# ---- Helper methods ----
def _ensure_tmp_dir(self) -> str:
"""Ensure workspace/tmp directory exists and return its path."""
tmp_dir = os.path.join(self.cwd, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
return tmp_dir
def _extract_filename(self, url: str) -> str:
"""Extract a safe filename from URL, with a short UUID prefix to avoid collisions."""
path = urlparse(url).path
basename = os.path.basename(unquote(path))
if not basename or basename == "/":
basename = "downloaded_file"
# Sanitize: keep only safe chars
basename = re.sub(r'[^\w.\-]', '_', basename)
short_id = uuid.uuid4().hex[:8]
return f"{short_id}_{basename}"
@staticmethod
def _cleanup_file(path: str):
"""Remove a file if it exists, ignoring errors."""
try:
if os.path.exists(path):
os.remove(path)
except Exception:
pass
@staticmethod
def _is_binary_content_type(content_type: str) -> bool:
"""Check if Content-Type indicates a binary/document response."""
binary_types = [
"application/pdf",
"application/vnd.openxmlformats",
"application/vnd.ms-excel",
"application/vnd.ms-powerpoint",
"application/octet-stream",
]
ct_lower = content_type.lower()
return any(bt in ct_lower for bt in binary_types)
def _handle_download_by_content_type(self, url: str, response: requests.Response, content_type: str) -> ToolResult:
"""Handle a URL that returned binary content instead of HTML."""
ct_lower = content_type.lower()
suffix_map = {
"application/pdf": ".pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml": ".docx",
"application/vnd.ms-excel": ".xls",
"application/vnd.openxmlformats-officedocument.spreadsheetml": ".xlsx",
"application/vnd.ms-powerpoint": ".ppt",
"application/vnd.openxmlformats-officedocument.presentationml": ".pptx",
}
detected_suffix = None
for ct_prefix, ext in suffix_map.items():
if ct_prefix in ct_lower:
detected_suffix = ext
break
if detected_suffix and detected_suffix in ALL_DOC_SUFFIXES:
# Re-fetch as document
return self._fetch_document(url if _get_url_suffix(url) in ALL_DOC_SUFFIXES
else self._rewrite_url_with_suffix(url, detected_suffix))
return ToolResult.fail(f"Error: URL returned binary content ({content_type}), not a supported document type")
@staticmethod
def _rewrite_url_with_suffix(url: str, suffix: str) -> str:
"""Append a suffix to the URL path so _get_url_suffix works correctly."""
parsed = urlparse(url)
new_path = parsed.path.rstrip("/") + suffix
return parsed._replace(path=new_path).geturl()
# ---- HTML extraction (unchanged) ----
@staticmethod
def _extract_title(html: str) -> str:
match = re.search(r"<title[^>]*>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
return match.group(1).strip() if match else "Untitled"
@staticmethod
def _extract_text(html: str) -> str:
text = re.sub(r"<script[^>]*>.*?</script>", "", html, flags=re.IGNORECASE | re.DOTALL)
text = re.sub(r"<style[^>]*>.*?</style>", "", text, flags=re.IGNORECASE | re.DOTALL)
text = re.sub(r"<[^>]+>", "", text)
text = text.replace("&amp;", "&").replace("&lt;", "<").replace("&gt;", ">")
text = text.replace("&quot;", '"').replace("&#39;", "'").replace("&nbsp;", " ")
text = re.sub(r"[^\S\n]+", " ", text)
text = re.sub(r"\n{3,}", "\n\n", text)
lines = [line.strip() for line in text.splitlines()]
text = "\n".join(lines)
return text.strip()

View File

@@ -0,0 +1,3 @@
from agent.tools.web_search.web_search import WebSearch
__all__ = ["WebSearch"]

View File

@@ -0,0 +1,318 @@
"""
Web Search tool - Search the web using Bocha or LinkAI search API.
Supports two backends with unified response format:
1. Bocha Search (primary, requires BOCHA_API_KEY)
2. LinkAI Search (fallback, requires LINKAI_API_KEY)
"""
import os
import json
from typing import Dict, Any, Optional
import requests
from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger
from config import conf
# Default timeout for API requests (seconds)
DEFAULT_TIMEOUT = 30
class WebSearch(BaseTool):
"""Tool for searching the web using Bocha or LinkAI search API"""
name: str = "web_search"
description: str = "Search the web for real-time information. Returns titles, URLs, and snippets."
params: dict = {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query string"
},
"count": {
"type": "integer",
"description": "Number of results to return (1-50, default: 10)"
},
"freshness": {
"type": "string",
"description": (
"Time range filter. Options: "
"'noLimit' (default), 'oneDay', 'oneWeek', 'oneMonth', 'oneYear', "
"or date range like '2025-01-01..2025-02-01'"
)
},
"summary": {
"type": "boolean",
"description": "Whether to include text summary for each result (default: false)"
}
},
"required": ["query"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self._backend = None # Will be resolved on first execute
@staticmethod
def is_available() -> bool:
"""Check if web search is available (at least one API key is configured)"""
return bool(os.environ.get("BOCHA_API_KEY") or os.environ.get("LINKAI_API_KEY"))
def _resolve_backend(self) -> Optional[str]:
"""
Determine which search backend to use.
Priority: Bocha > LinkAI
:return: 'bocha', 'linkai', or None
"""
if os.environ.get("BOCHA_API_KEY"):
return "bocha"
if os.environ.get("LINKAI_API_KEY"):
return "linkai"
return None
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute web search
:param args: Search parameters (query, count, freshness, summary)
:return: Search results
"""
query = args.get("query", "").strip()
if not query:
return ToolResult.fail("Error: 'query' parameter is required")
count = args.get("count", 10)
freshness = args.get("freshness", "noLimit")
summary = args.get("summary", False)
# Validate count
if not isinstance(count, int) or count < 1 or count > 50:
count = 10
# Resolve backend
backend = self._resolve_backend()
if not backend:
return ToolResult.fail(
"Error: No search API key configured. "
"Please set BOCHA_API_KEY or LINKAI_API_KEY using env_config tool.\n"
" - Bocha Search: https://open.bocha.cn\n"
" - LinkAI Search: https://link-ai.tech"
)
try:
if backend == "bocha":
return self._search_bocha(query, count, freshness, summary)
else:
return self._search_linkai(query, count, freshness)
except requests.Timeout:
return ToolResult.fail(f"Error: Search request timed out after {DEFAULT_TIMEOUT}s")
except requests.ConnectionError:
return ToolResult.fail("Error: Failed to connect to search API")
except Exception as e:
logger.error(f"[WebSearch] Unexpected error: {e}", exc_info=True)
return ToolResult.fail(f"Error: Search failed - {str(e)}")
def _search_bocha(self, query: str, count: int, freshness: str, summary: bool) -> ToolResult:
"""
Search using Bocha API
:param query: Search query
:param count: Number of results
:param freshness: Time range filter
:param summary: Whether to include summary
:return: Formatted search results
"""
api_key = os.environ.get("BOCHA_API_KEY", "")
url = "https://api.bocha.cn/v1/web-search"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"Accept": "application/json"
}
payload = {
"query": query,
"count": count,
"freshness": freshness,
"summary": summary
}
logger.debug(f"[WebSearch] Bocha search: query='{query}', count={count}")
response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
if response.status_code == 401:
return ToolResult.fail("Error: Invalid BOCHA_API_KEY. Please check your API key.")
if response.status_code == 403:
return ToolResult.fail("Error: Bocha API - insufficient balance. Please top up at https://open.bocha.cn")
if response.status_code == 429:
return ToolResult.fail("Error: Bocha API rate limit reached. Please try again later.")
if response.status_code != 200:
return ToolResult.fail(f"Error: Bocha API returned HTTP {response.status_code}")
data = response.json()
# Check API-level error code
api_code = data.get("code")
if api_code is not None and api_code != 200:
msg = data.get("msg") or "Unknown error"
return ToolResult.fail(f"Error: Bocha API error (code={api_code}): {msg}")
# Extract and format results
return self._format_bocha_results(data, query)
def _format_bocha_results(self, data: dict, query: str) -> ToolResult:
"""
Format Bocha API response into unified result structure
:param data: Raw API response
:param query: Original query
:return: Formatted ToolResult
"""
search_data = data.get("data", {})
web_pages = search_data.get("webPages", {})
pages = web_pages.get("value", [])
if not pages:
return ToolResult.success({
"query": query,
"backend": "bocha",
"total": 0,
"results": [],
"message": "No results found"
})
results = []
for page in pages:
result = {
"title": page.get("name", ""),
"url": page.get("url", ""),
"snippet": page.get("snippet", ""),
"siteName": page.get("siteName", ""),
"datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""),
}
# Include summary only if present
if page.get("summary"):
result["summary"] = page["summary"]
results.append(result)
total = web_pages.get("totalEstimatedMatches", len(results))
return ToolResult.success({
"query": query,
"backend": "bocha",
"total": total,
"count": len(results),
"results": results
})
def _search_linkai(self, query: str, count: int, freshness: str) -> ToolResult:
"""
Search using LinkAI plugin API
:param query: Search query
:param count: Number of results
:param freshness: Time range filter
:return: Formatted search results
"""
api_key = os.environ.get("LINKAI_API_KEY", "")
api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
url = f"{api_base.rstrip('/')}/v1/plugin/execute"
from common.utils import get_cloud_headers
headers = get_cloud_headers(api_key)
payload = {
"code": "web-search",
"args": {
"query": query,
"count": count,
"freshness": freshness
}
}
logger.debug(f"[WebSearch] LinkAI search: query='{query}', count={count}")
response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
if response.status_code == 401:
return ToolResult.fail("Error: Invalid LINKAI_API_KEY. Please check your API key.")
if response.status_code != 200:
return ToolResult.fail(f"Error: LinkAI API returned HTTP {response.status_code}")
data = response.json()
if not data.get("success"):
msg = data.get("message") or "Unknown error"
return ToolResult.fail(f"Error: LinkAI search failed: {msg}")
return self._format_linkai_results(data, query)
def _format_linkai_results(self, data: dict, query: str) -> ToolResult:
"""
Format LinkAI API response into unified result structure.
LinkAI returns the search data in data.data field, which follows
the same Bing-compatible format as Bocha.
:param data: Raw API response
:param query: Original query
:return: Formatted ToolResult
"""
raw_data = data.get("data", "")
# LinkAI may return data as a JSON string
if isinstance(raw_data, str):
try:
raw_data = json.loads(raw_data)
except (json.JSONDecodeError, TypeError):
# If data is plain text, return it as a single result
return ToolResult.success({
"query": query,
"backend": "linkai",
"total": 1,
"count": 1,
"results": [{"content": raw_data}]
})
# If the response follows Bing-compatible structure
if isinstance(raw_data, dict):
web_pages = raw_data.get("webPages", {})
pages = web_pages.get("value", [])
if pages:
results = []
for page in pages:
result = {
"title": page.get("name", ""),
"url": page.get("url", ""),
"snippet": page.get("snippet", ""),
"siteName": page.get("siteName", ""),
"datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""),
}
if page.get("summary"):
result["summary"] = page["summary"]
results.append(result)
total = web_pages.get("totalEstimatedMatches", len(results))
return ToolResult.success({
"query": query,
"backend": "linkai",
"total": total,
"count": len(results),
"results": results
})
# Fallback: return raw data
return ToolResult.success({
"query": query,
"backend": "linkai",
"total": 1,
"count": 1,
"results": [{"content": str(raw_data)}]
})

View File

@@ -0,0 +1,3 @@
from .write import Write
__all__ = ['Write']

View File

@@ -0,0 +1,97 @@
"""
Write tool - Write file content
Creates or overwrites files, automatically creates parent directories
"""
import os
from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool, ToolResult
from common.utils import expand_path
class Write(BaseTool):
"""Tool for writing file content"""
name: str = "write"
description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories. IMPORTANT: Single write should not exceed 10KB. For large files, create a skeleton first, then use edit to add content in chunks."
params: dict = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the file to write (relative or absolute)"
},
"content": {
"type": "string",
"description": "Content to write to the file"
}
},
"required": ["path", "content"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
self.memory_manager = self.config.get("memory_manager", None)
def execute(self, args: Dict[str, Any]) -> ToolResult:
"""
Execute file write operation
:param args: Contains file path and content
:return: Operation result
"""
path = args.get("path", "").strip()
content = args.get("content", "")
if not path:
return ToolResult.fail("Error: path parameter is required")
# Resolve path
absolute_path = self._resolve_path(path)
try:
# Create parent directory (if needed)
parent_dir = os.path.dirname(absolute_path)
if parent_dir:
os.makedirs(parent_dir, exist_ok=True)
# Write file
with open(absolute_path, 'w', encoding='utf-8') as f:
f.write(content)
# Get bytes written
bytes_written = len(content.encode('utf-8'))
# Auto-sync to memory database if this is a memory file
if self.memory_manager and 'memory/' in path:
self.memory_manager.mark_dirty()
result = {
"message": f"Successfully wrote {bytes_written} bytes to {path}",
"path": path,
"bytes_written": bytes_written
}
return ToolResult.success(result)
except PermissionError:
return ToolResult.fail(f"Error: Permission denied writing to {path}")
except Exception as e:
return ToolResult.fail(f"Error writing file: {str(e)}")
def _resolve_path(self, path: str) -> str:
"""
Resolve path to absolute path
:param path: Relative or absolute path
:return: Absolute path
"""
# Expand ~ to user home directory
path = expand_path(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))

282
app.py
View File

@@ -3,11 +3,262 @@
import os
import signal
import sys
import time
from channel import channel_factory
from common import const
from config import load_config
from common.log import logger
from config import load_config, conf
from plugins import *
import threading
_channel_mgr = None
def get_channel_manager():
return _channel_mgr
def _parse_channel_type(raw) -> list:
"""
Parse channel_type config value into a list of channel names.
Supports:
- single string: "feishu"
- comma-separated string: "feishu, dingtalk"
- list: ["feishu", "dingtalk"]
"""
if isinstance(raw, list):
return [ch.strip() for ch in raw if ch.strip()]
if isinstance(raw, str):
return [ch.strip() for ch in raw.split(",") if ch.strip()]
return []
class ChannelManager:
"""
Manage the lifecycle of multiple channels running concurrently.
Each channel.startup() runs in its own daemon thread.
The web channel is started as default console unless explicitly disabled.
"""
def __init__(self):
self._channels = {} # channel_name -> channel instance
self._threads = {} # channel_name -> thread
self._primary_channel = None
self._lock = threading.Lock()
self.cloud_mode = False # set to True when cloud client is active
@property
def channel(self):
"""Return the primary (first non-web) channel for backward compatibility."""
return self._primary_channel
def get_channel(self, channel_name: str):
return self._channels.get(channel_name)
def start(self, channel_names: list, first_start: bool = False):
"""
Create and start one or more channels in sub-threads.
If first_start is True, plugins and linkai client will also be initialized.
"""
with self._lock:
channels = []
for name in channel_names:
ch = channel_factory.create_channel(name)
ch.cloud_mode = self.cloud_mode
self._channels[name] = ch
channels.append((name, ch))
if self._primary_channel is None and name != "web":
self._primary_channel = ch
if self._primary_channel is None and channels:
self._primary_channel = channels[0][1]
if first_start:
PluginManager().load_plugins()
# Cloud client is optional. It is only started when
# use_linkai=True AND cloud_deployment_id is set.
# By default neither is configured, so the app runs
# entirely locally without any remote connection.
if conf().get("use_linkai") and (
os.environ.get("CLOUD_DEPLOYMENT_ID") or conf().get("cloud_deployment_id")
):
try:
from common import cloud_client
threading.Thread(
target=cloud_client.start,
args=(self._primary_channel, self),
daemon=True,
).start()
except Exception:
pass
# Start web console first so its logs print cleanly,
# then start remaining channels after a brief pause.
web_entry = None
other_entries = []
for entry in channels:
if entry[0] == "web":
web_entry = entry
else:
other_entries.append(entry)
ordered = ([web_entry] if web_entry else []) + other_entries
for i, (name, ch) in enumerate(ordered):
if i > 0 and name != "web":
time.sleep(0.1)
t = threading.Thread(target=self._run_channel, args=(name, ch), daemon=True)
self._threads[name] = t
t.start()
logger.debug(f"[ChannelManager] Channel '{name}' started in sub-thread")
def _run_channel(self, name: str, channel):
try:
channel.startup()
except Exception as e:
logger.error(f"[ChannelManager] Channel '{name}' startup error: {e}")
logger.exception(e)
def stop(self, channel_name: str = None):
"""
Stop channel(s). If channel_name is given, stop only that channel;
otherwise stop all channels.
"""
# Pop under lock, then stop outside lock to avoid deadlock
with self._lock:
names = [channel_name] if channel_name else list(self._channels.keys())
to_stop = []
for name in names:
ch = self._channels.pop(name, None)
th = self._threads.pop(name, None)
to_stop.append((name, ch, th))
if channel_name and self._primary_channel is self._channels.get(channel_name):
self._primary_channel = None
for name, ch, th in to_stop:
if ch is None:
logger.warning(f"[ChannelManager] Channel '{name}' not found in managed channels")
if th and th.is_alive():
self._interrupt_thread(th, name)
continue
logger.info(f"[ChannelManager] Stopping channel '{name}'...")
graceful = False
if hasattr(ch, 'stop'):
try:
ch.stop()
graceful = True
except Exception as e:
logger.warning(f"[ChannelManager] Error during channel '{name}' stop: {e}")
if th and th.is_alive():
th.join(timeout=5)
if th.is_alive():
if graceful:
logger.info(f"[ChannelManager] Channel '{name}' thread still alive after stop(), "
"leaving daemon thread to finish on its own")
else:
logger.warning(f"[ChannelManager] Channel '{name}' thread did not exit in 5s, forcing interrupt")
self._interrupt_thread(th, name)
@staticmethod
def _interrupt_thread(th: threading.Thread, name: str):
"""Raise SystemExit in target thread to break blocking loops like start_forever."""
import ctypes
try:
tid = th.ident
if tid is None:
return
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
)
if res == 1:
logger.info(f"[ChannelManager] Interrupted thread for channel '{name}'")
elif res > 1:
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
logger.warning(f"[ChannelManager] Failed to interrupt thread for channel '{name}'")
except Exception as e:
logger.warning(f"[ChannelManager] Thread interrupt error for '{name}': {e}")
def restart(self, new_channel_name: str):
"""
Restart a single channel with a new channel type.
Can be called from any thread (e.g. linkai config callback).
"""
logger.info(f"[ChannelManager] Restarting channel to '{new_channel_name}'...")
self.stop(new_channel_name)
_clear_singleton_cache(new_channel_name)
time.sleep(1)
self.start([new_channel_name], first_start=False)
logger.info(f"[ChannelManager] Channel restarted to '{new_channel_name}' successfully")
def add_channel(self, channel_name: str):
"""
Dynamically add and start a new channel.
If the channel is already running, restart it instead.
"""
with self._lock:
if channel_name in self._channels:
logger.info(f"[ChannelManager] Channel '{channel_name}' already exists, restarting")
if self._channels.get(channel_name):
self.restart(channel_name)
return
logger.info(f"[ChannelManager] Adding channel '{channel_name}'...")
_clear_singleton_cache(channel_name)
self.start([channel_name], first_start=False)
logger.info(f"[ChannelManager] Channel '{channel_name}' added successfully")
def remove_channel(self, channel_name: str):
"""
Dynamically stop and remove a running channel.
"""
with self._lock:
if channel_name not in self._channels:
logger.warning(f"[ChannelManager] Channel '{channel_name}' not found, nothing to remove")
return
logger.info(f"[ChannelManager] Removing channel '{channel_name}'...")
self.stop(channel_name)
logger.info(f"[ChannelManager] Channel '{channel_name}' removed successfully")
def _clear_singleton_cache(channel_name: str):
"""
Clear the singleton cache for the channel class so that
a new instance can be created with updated config.
"""
cls_map = {
"web": "channel.web.web_channel.WebChannel",
"wechatmp": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
"wechatmp_service": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
"wechatcom_app": "channel.wechatcom.wechatcomapp_channel.WechatComAppChannel",
const.FEISHU: "channel.feishu.feishu_channel.FeiShuChanel",
const.DINGTALK: "channel.dingtalk.dingtalk_channel.DingTalkChanel",
const.WECOM_BOT: "channel.wecom_bot.wecom_bot_channel.WecomBotChannel",
const.QQ: "channel.qq.qq_channel.QQChannel",
const.WEIXIN: "channel.weixin.weixin_channel.WeixinChannel",
"wx": "channel.weixin.weixin_channel.WeixinChannel",
}
module_path = cls_map.get(channel_name)
if not module_path:
return
try:
parts = module_path.rsplit(".", 1)
module_name, class_name = parts[0], parts[1]
import importlib
module = importlib.import_module(module_name)
wrapper = getattr(module, class_name, None)
if wrapper and hasattr(wrapper, '__closure__') and wrapper.__closure__:
for cell in wrapper.__closure__:
try:
cell_contents = cell.cell_contents
if isinstance(cell_contents, dict):
cell_contents.clear()
logger.debug(f"[ChannelManager] Cleared singleton cache for {class_name}")
break
except ValueError:
pass
except Exception as e:
logger.warning(f"[ChannelManager] Failed to clear singleton cache: {e}")
def sigterm_handler_wrap(_signo):
@@ -24,6 +275,7 @@ def sigterm_handler_wrap(_signo):
def run():
global _channel_mgr
try:
# load config
load_config()
@@ -32,22 +284,28 @@ def run():
# kill signal
sigterm_handler_wrap(signal.SIGTERM)
# create channel
channel_name = conf().get("channel_type", "wx")
# Parse channel_type into a list
raw_channel = conf().get("channel_type", "web")
if "--cmd" in sys.argv:
channel_name = "terminal"
channel_names = ["terminal"]
else:
channel_names = _parse_channel_type(raw_channel)
if not channel_names:
channel_names = ["web"]
if channel_name == "wxy":
os.environ["WECHATY_LOG"] = "warn"
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
# Auto-start web console unless explicitly disabled
web_console_enabled = conf().get("web_console", True)
if web_console_enabled and "web" not in channel_names:
channel_names.append("web")
channel = channel_factory.create_channel(channel_name)
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU]:
PluginManager().load_plugins()
logger.info(f"[App] Starting channels: {channel_names}")
# startup channel
channel.startup()
_channel_mgr = ChannelManager()
_channel_mgr.start(channel_names, first_start=True)
while True:
time.sleep(1)
except Exception as e:
logger.error("App startup failed!")
logger.exception(e)

View File

@@ -1,46 +0,0 @@
"""
channel factory
"""
from common import const
def create_bot(bot_type):
"""
create a bot_type instance
:param bot_type: bot type code
:return: bot instance
"""
if bot_type == const.BAIDU:
# 替换Baidu Unit为Baidu文心千帆对话接口
# from bot.baidu.baidu_unit_bot import BaiduUnitBot
# return BaiduUnitBot()
from bot.baidu.baidu_wenxin import BaiduWenxinBot
return BaiduWenxinBot()
elif bot_type == const.CHATGPT:
# ChatGPT 网页端web接口
from bot.chatgpt.chat_gpt_bot import ChatGPTBot
return ChatGPTBot()
elif bot_type == const.OPEN_AI:
# OpenAI 官方对话模型API
from bot.openai.open_ai_bot import OpenAIBot
return OpenAIBot()
elif bot_type == const.CHATGPTONAZURE:
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
return AzureChatGPTBot()
elif bot_type == const.XUNFEI:
from bot.xunfei.xunfei_spark_bot import XunFeiBot
return XunFeiBot()
elif bot_type == const.LINKAI:
from bot.linkai.link_ai_bot import LinkAIBot
return LinkAIBot()
elif bot_type == const.CLAUDEAI:
from bot.claude.claude_ai_bot import ClaudeAIBot
return ClaudeAIBot()
raise RuntimeError

View File

@@ -1,193 +0,0 @@
# encoding:utf-8
import time
import openai
import openai.error
import requests
from bot.bot import Bot
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.openai.open_ai_image import OpenAIImage
from bot.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from common.token_bucket import TokenBucket
from config import conf, load_config
# OpenAI对话模型API (可用)
class ChatGPTBot(Bot, OpenAIImage):
def __init__(self):
super().__init__()
# set the default api_key
openai.api_key = conf().get("open_ai_api_key")
if conf().get("open_ai_api_base"):
openai.api_base = conf().get("open_ai_api_base")
proxy = conf().get("proxy")
if proxy:
openai.proxy = proxy
if conf().get("rate_limit_chatgpt"):
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
self.args = {
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
# "max_tokens":4096, # 回复最大的字符数
"top_p": conf().get("top_p", 1),
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get("request_timeout", None), # 请求超时时间openai接口默认设置为600对于难问题一般需要较长时间
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
}
def reply(self, query, context=None):
# acquire reply content
if context.type == ContextType.TEXT:
logger.info("[CHATGPT] query={}".format(query))
session_id = context["session_id"]
reply = None
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
if query in clear_memory_commands:
self.sessions.clear_session(session_id)
reply = Reply(ReplyType.INFO, "记忆已清除")
elif query == "#清除所有":
self.sessions.clear_all_session()
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
elif query == "#更新配置":
load_config()
reply = Reply(ReplyType.INFO, "配置已更新")
if reply:
return reply
session = self.sessions.session_query(query, session_id)
logger.debug("[CHATGPT] session query={}".format(session.messages))
api_key = context.get("openai_api_key")
model = context.get("gpt_model")
new_args = None
if model:
new_args = self.args.copy()
new_args["model"] = model
# if context.get('stream'):
# # reply in stream
# return self.reply_text_stream(query, new_query, session_id)
reply_content = self.reply_text(session, api_key, args=new_args)
logger.debug(
"[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
session.messages,
session_id,
reply_content["content"],
reply_content["completion_tokens"],
)
)
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
reply = Reply(ReplyType.ERROR, reply_content["content"])
elif reply_content["completion_tokens"] > 0:
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
reply = Reply(ReplyType.TEXT, reply_content["content"])
else:
reply = Reply(ReplyType.ERROR, reply_content["content"])
logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
return reply
elif context.type == ContextType.IMAGE_CREATE:
ok, retstring = self.create_img(query, 0)
reply = None
if ok:
reply = Reply(ReplyType.IMAGE_URL, retstring)
else:
reply = Reply(ReplyType.ERROR, retstring)
return reply
else:
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
return reply
def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict:
"""
call openai's ChatCompletion to get the answer
:param session: a conversation session
:param session_id: session id
:param retry_count: retry count
:return: {}
"""
try:
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used
if args is None:
args = self.args
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
# logger.debug("[CHATGPT] response={}".format(response))
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return {
"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response["usage"]["completion_tokens"],
"content": response.choices[0]["message"]["content"],
}
except Exception as e:
need_retry = retry_count < 2
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
if isinstance(e, openai.error.RateLimitError):
logger.warn("[CHATGPT] RateLimitError: {}".format(e))
result["content"] = "提问太快啦,请休息一下再问我吧"
if need_retry:
time.sleep(20)
elif isinstance(e, openai.error.Timeout):
logger.warn("[CHATGPT] Timeout: {}".format(e))
result["content"] = "我没有收到你的消息"
if need_retry:
time.sleep(5)
elif isinstance(e, openai.error.APIError):
logger.warn("[CHATGPT] Bad Gateway: {}".format(e))
result["content"] = "请再问我一次"
if need_retry:
time.sleep(10)
elif isinstance(e, openai.error.APIConnectionError):
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
need_retry = False
result["content"] = "我连接不到你的网络"
else:
logger.exception("[CHATGPT] Exception: {}".format(e))
need_retry = False
self.sessions.clear_session(session.session_id)
if need_retry:
logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
return self.reply_text(session, api_key, args, retry_count + 1)
else:
return result
class AzureChatGPTBot(ChatGPTBot):
def __init__(self):
super().__init__()
openai.api_type = "azure"
openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
self.args["deployment_id"] = conf().get("azure_deployment_id")
def create_img(self, query, retry_count=0, api_key=None):
api_version = "2022-08-03-preview"
url = "{}dalle/text-to-image?api-version={}".format(openai.api_base, api_version)
api_key = api_key or openai.api_key
headers = {"api-key": api_key, "Content-Type": "application/json"}
try:
body = {"caption": query, "resolution": conf().get("image_create_size", "256x256")}
submission = requests.post(url, headers=headers, json=body)
operation_location = submission.headers["Operation-Location"]
retry_after = submission.headers["Retry-after"]
status = ""
image_url = ""
while status != "Succeeded":
logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds")
time.sleep(int(retry_after))
response = requests.get(operation_location, headers=headers)
status = response.json()["status"]
image_url = response.json()["result"]["contentUrl"]
return True, image_url
except Exception as e:
logger.error("create image error: {}".format(e))
return False, "图片生成失败"

View File

@@ -1,222 +0,0 @@
import re
import time
import json
import uuid
from curl_cffi import requests
from bot.bot import Bot
from bot.claude.claude_ai_session import ClaudeAiSession
from bot.openai.open_ai_image import OpenAIImage
from bot.session_manager import SessionManager
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf
class ClaudeAIBot(Bot, OpenAIImage):
def __init__(self):
super().__init__()
self.sessions = SessionManager(ClaudeAiSession, model=conf().get("model") or "gpt-3.5-turbo")
self.claude_api_cookie = conf().get("claude_api_cookie")
self.proxy = conf().get("proxy")
self.con_uuid_dic = {}
if self.proxy:
self.proxies = {
"http": self.proxy,
"https": self.proxy
}
else:
self.proxies = None
self.error = ""
self.org_uuid = self.get_organization_id()
def generate_uuid(self):
random_uuid = uuid.uuid4()
random_uuid_str = str(random_uuid)
formatted_uuid = f"{random_uuid_str[0:8]}-{random_uuid_str[9:13]}-{random_uuid_str[14:18]}-{random_uuid_str[19:23]}-{random_uuid_str[24:]}"
return formatted_uuid
def reply(self, query, context: Context = None) -> Reply:
if context.type == ContextType.TEXT:
return self._chat(query, context)
elif context.type == ContextType.IMAGE_CREATE:
ok, res = self.create_img(query, 0)
if ok:
reply = Reply(ReplyType.IMAGE_URL, res)
else:
reply = Reply(ReplyType.ERROR, res)
return reply
else:
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
return reply
def get_organization_id(self):
url = "https://claude.ai/api/organizations"
headers = {
'User-Agent':
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
'Accept-Language': 'en-US,en;q=0.5',
'Referer': 'https://claude.ai/chats',
'Content-Type': 'application/json',
'Sec-Fetch-Dest': 'empty',
'Sec-Fetch-Mode': 'cors',
'Sec-Fetch-Site': 'same-origin',
'Connection': 'keep-alive',
'Cookie': f'{self.claude_api_cookie}'
}
try:
response = requests.get(url, headers=headers, impersonate="chrome110", proxies =self.proxies, timeout=400)
res = json.loads(response.text)
uuid = res[0]['uuid']
except:
if "App unavailable" in response.text:
logger.error("IP error: The IP is not allowed to be used on Claude")
self.error = "ip所在地区不被claude支持"
elif "Invalid authorization" in response.text:
logger.error("Cookie error: Invalid authorization of claude, check cookie please.")
self.error = "无法通过claude身份验证请检查cookie"
return None
return uuid
def conversation_share_check(self,session_id):
if conf().get("claude_uuid") is not None and conf().get("claude_uuid") != "":
con_uuid = conf().get("claude_uuid")
return con_uuid
if session_id not in self.con_uuid_dic:
self.con_uuid_dic[session_id] = self.generate_uuid()
self.create_new_chat(self.con_uuid_dic[session_id])
return self.con_uuid_dic[session_id]
def check_cookie(self):
flag = self.get_organization_id()
return flag
def create_new_chat(self, con_uuid):
"""
新建claude对话实体
:param con_uuid: 对话id
:return:
"""
url = f"https://claude.ai/api/organizations/{self.org_uuid}/chat_conversations"
payload = json.dumps({"uuid": con_uuid, "name": ""})
headers = {
'User-Agent':
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
'Accept-Language': 'en-US,en;q=0.5',
'Referer': 'https://claude.ai/chats',
'Content-Type': 'application/json',
'Origin': 'https://claude.ai',
'DNT': '1',
'Connection': 'keep-alive',
'Cookie': self.claude_api_cookie,
'Sec-Fetch-Dest': 'empty',
'Sec-Fetch-Mode': 'cors',
'Sec-Fetch-Site': 'same-origin',
'TE': 'trailers'
}
response = requests.post(url, headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
# Returns JSON of the newly created conversation information
return response.json()
def _chat(self, query, context, retry_count=0) -> Reply:
"""
发起对话请求
:param query: 请求提示词
:param context: 对话上下文
:param retry_count: 当前递归重试次数
:return: 回复
"""
if retry_count >= 2:
# exit from retry 2 times
logger.warn("[CLAUDEAI] failed after maximum number of retry times")
return Reply(ReplyType.ERROR, "请再问我一次吧")
try:
session_id = context["session_id"]
if self.org_uuid is None:
return Reply(ReplyType.ERROR, self.error)
session = self.sessions.session_query(query, session_id)
con_uuid = self.conversation_share_check(session_id)
model = conf().get("model") or "gpt-3.5-turbo"
# remove system message
if session.messages[0].get("role") == "system":
if model == "wenxin" or model == "claude":
session.messages.pop(0)
logger.info(f"[CLAUDEAI] query={query}")
# do http request
base_url = "https://claude.ai"
payload = json.dumps({
"completion": {
"prompt": f"{query}",
"timezone": "Asia/Kolkata",
"model": "claude-2"
},
"organization_uuid": f"{self.org_uuid}",
"conversation_uuid": f"{con_uuid}",
"text": f"{query}",
"attachments": []
})
headers = {
'User-Agent':
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
'Accept': 'text/event-stream, text/event-stream',
'Accept-Language': 'en-US,en;q=0.5',
'Referer': 'https://claude.ai/chats',
'Content-Type': 'application/json',
'Origin': 'https://claude.ai',
'DNT': '1',
'Connection': 'keep-alive',
'Cookie': f'{self.claude_api_cookie}',
'Sec-Fetch-Dest': 'empty',
'Sec-Fetch-Mode': 'cors',
'Sec-Fetch-Site': 'same-origin',
'TE': 'trailers'
}
res = requests.post(base_url + "/api/append_message", headers=headers, data=payload,impersonate="chrome110",proxies= self.proxies,timeout=400)
if res.status_code == 200 or "pemission" in res.text:
# execute success
decoded_data = res.content.decode("utf-8")
decoded_data = re.sub('\n+', '\n', decoded_data).strip()
data_strings = decoded_data.split('\n')
completions = []
for data_string in data_strings:
json_str = data_string[6:].strip()
data = json.loads(json_str)
if 'completion' in data:
completions.append(data['completion'])
reply_content = ''.join(completions)
if "rate limi" in reply_content:
logger.error("rate limit error: The conversation has reached the system speed limit and is synchronized with Cladue. Please go to the official website to check the lifting time")
return Reply(ReplyType.ERROR, "对话达到系统速率限制与cladue同步请进入官网查看解除限制时间")
logger.info(f"[CLAUDE] reply={reply_content}, total_tokens=invisible")
self.sessions.session_reply(reply_content, session_id, 100)
return Reply(ReplyType.TEXT, reply_content)
else:
flag = self.check_cookie()
if flag == None:
return Reply(ReplyType.ERROR, self.error)
response = res.json()
error = response.get("error")
logger.error(f"[CLAUDE] chat failed, status_code={res.status_code}, "
f"msg={error.get('message')}, type={error.get('type')}, detail: {res.text}, uuid: {con_uuid}")
if res.status_code >= 500:
# server error, need retry
time.sleep(2)
logger.warn(f"[CLAUDE] do retry, times={retry_count}")
return self._chat(query, context, retry_count + 1)
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
except Exception as e:
logger.exception(e)
# retry
time.sleep(2)
logger.warn(f"[CLAUDE] do retry, times={retry_count}")
return self._chat(query, context, retry_count + 1)

View File

@@ -1,9 +0,0 @@
from bot.session_manager import Session
class ClaudeAiSession(Session):
def __init__(self, session_id, system_prompt=None, model="claude"):
super().__init__(session_id, system_prompt)
self.model = model
# claude逆向不支持role prompt
# self.reset()

690
bridge/agent_bridge.py Normal file
View File

@@ -0,0 +1,690 @@
"""
Agent Bridge - Integrates Agent system with existing COW bridge
"""
import os
from typing import Optional, List
from agent.protocol import Agent, LLMModel, LLMRequest
from bridge.agent_event_handler import AgentEventHandler
from bridge.agent_initializer import AgentInitializer
from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import Reply, ReplyType
from common import const
from common.log import logger
from common.utils import expand_path
from models.openai_compatible_bot import OpenAICompatibleBot
def add_openai_compatible_support(bot_instance):
"""
Dynamically add OpenAI-compatible tool calling support to a bot instance.
This allows any bot to gain tool calling capability without modifying its code,
as long as it uses OpenAI-compatible API format.
Note: Some bots like ZHIPUAIBot have native tool calling support and don't need enhancement.
"""
if hasattr(bot_instance, 'call_with_tools'):
# Bot already has tool calling support (e.g., ZHIPUAIBot)
logger.debug(f"[AgentBridge] {type(bot_instance).__name__} already has native tool calling support")
return bot_instance
# Create a temporary mixin class that combines the bot with OpenAI compatibility
class EnhancedBot(bot_instance.__class__, OpenAICompatibleBot):
"""Dynamically enhanced bot with OpenAI-compatible tool calling"""
def get_api_config(self):
"""
Infer API config from common configuration patterns.
Most OpenAI-compatible bots use similar configuration.
"""
from config import conf
return {
'api_key': conf().get("open_ai_api_key"),
'api_base': conf().get("open_ai_api_base"),
'model': conf().get("model", "gpt-3.5-turbo"),
'default_temperature': conf().get("temperature", 0.9),
'default_top_p': conf().get("top_p", 1.0),
'default_frequency_penalty': conf().get("frequency_penalty", 0.0),
'default_presence_penalty': conf().get("presence_penalty", 0.0),
}
# Change the bot's class to the enhanced version
bot_instance.__class__ = EnhancedBot
logger.info(
f"[AgentBridge] Enhanced {bot_instance.__class__.__bases__[0].__name__} with OpenAI-compatible tool calling")
return bot_instance
class AgentLLMModel(LLMModel):
"""
LLM Model adapter that uses COW's existing bot infrastructure
"""
_MODEL_BOT_TYPE_MAP = {
"wenxin": const.BAIDU, "wenxin-4": const.BAIDU,
"xunfei": const.XUNFEI, const.QWEN: const.QWEN,
const.MODELSCOPE: const.MODELSCOPE,
}
_MODEL_PREFIX_MAP = [
("qwen", const.QWEN_DASHSCOPE), ("qwq", const.QWEN_DASHSCOPE), ("qvq", const.QWEN_DASHSCOPE),
("gemini", const.GEMINI), ("glm", const.ZHIPU_AI), ("claude", const.CLAUDEAPI),
("moonshot", const.MOONSHOT), ("kimi", const.MOONSHOT),
("doubao", const.DOUBAO), ("deepseek", const.DEEPSEEK),
]
def __init__(self, bridge: Bridge, bot_type: str = "chat"):
from config import conf
super().__init__(model=conf().get("model", const.GPT_41))
self.bridge = bridge
self.bot_type = bot_type
self._bot = None
self._bot_model = None
@property
def model(self):
from config import conf
return conf().get("model", const.GPT_41)
@model.setter
def model(self, value):
pass
def _resolve_bot_type(self, model_name: str) -> str:
"""Resolve bot type from model name, matching Bridge.__init__ logic."""
from config import conf
if conf().get("use_linkai", False) and conf().get("linkai_api_key"):
return const.LINKAI
# Support custom bot type configuration
configured_bot_type = conf().get("bot_type")
if configured_bot_type:
return configured_bot_type
if not model_name or not isinstance(model_name, str):
return const.OPENAI
if model_name in self._MODEL_BOT_TYPE_MAP:
return self._MODEL_BOT_TYPE_MAP[model_name]
if model_name.lower().startswith("minimax") or model_name in ["abab6.5-chat"]:
return const.MiniMax
if model_name in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
return const.QWEN_DASHSCOPE
if model_name in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
return const.MOONSHOT
if conf().get("bot_type") == "modelscope":
return const.MODELSCOPE
for prefix, btype in self._MODEL_PREFIX_MAP:
if model_name.startswith(prefix):
return btype
return const.OPENAI
@property
def bot(self):
"""Lazy load the bot, re-create when model changes"""
from models.bot_factory import create_bot
cur_model = self.model
if self._bot is None or self._bot_model != cur_model:
bot_type = self._resolve_bot_type(cur_model)
self._bot = create_bot(bot_type)
self._bot = add_openai_compatible_support(self._bot)
self._bot_model = cur_model
return self._bot
def call(self, request: LLMRequest):
"""
Call the model using COW's bot infrastructure
"""
try:
# For non-streaming calls, we'll use the existing reply method
# This is a simplified implementation
if hasattr(self.bot, 'call_with_tools'):
# Use tool-enabled call if available
kwargs = {
'messages': request.messages,
'tools': getattr(request, 'tools', None),
'stream': False,
'model': self.model # Pass model parameter
}
# Only pass max_tokens if it's explicitly set
if request.max_tokens is not None:
kwargs['max_tokens'] = request.max_tokens
# Extract system prompt if present
system_prompt = getattr(request, 'system', None)
if system_prompt:
kwargs['system'] = system_prompt
# Pass context metadata to bot
channel_type = getattr(self, 'channel_type', None)
if channel_type:
kwargs['channel_type'] = channel_type
session_id = getattr(self, 'session_id', None)
if session_id:
kwargs['session_id'] = session_id
response = self.bot.call_with_tools(**kwargs)
return self._format_response(response)
else:
# Fallback to regular call
# This would need to be implemented based on your specific needs
raise NotImplementedError("Regular call not implemented yet")
except Exception as e:
logger.error(f"AgentLLMModel call error: {e}")
raise
def call_stream(self, request: LLMRequest):
"""
Call the model with streaming using COW's bot infrastructure
"""
try:
if hasattr(self.bot, 'call_with_tools'):
# Use tool-enabled streaming call if available
# Extract system prompt if present
system_prompt = getattr(request, 'system', None)
# Build kwargs for call_with_tools
kwargs = {
'messages': request.messages,
'tools': getattr(request, 'tools', None),
'stream': True,
'model': self.model # Pass model parameter
}
# Only pass max_tokens if explicitly set, let the bot use its default
if request.max_tokens is not None:
kwargs['max_tokens'] = request.max_tokens
# Add system prompt if present
if system_prompt:
kwargs['system'] = system_prompt
# Pass context metadata to bot
channel_type = getattr(self, 'channel_type', None)
if channel_type:
kwargs['channel_type'] = channel_type
session_id = getattr(self, 'session_id', None)
if session_id:
kwargs['session_id'] = session_id
stream = self.bot.call_with_tools(**kwargs)
# Convert stream format to our expected format
for chunk in stream:
yield self._format_stream_chunk(chunk)
else:
bot_type = type(self.bot).__name__
raise NotImplementedError(f"Bot {bot_type} does not support call_with_tools. Please add the method.")
except Exception as e:
logger.error(f"AgentLLMModel call_stream error: {e}", exc_info=True)
raise
def _format_response(self, response):
"""Format Claude response to our expected format"""
# This would need to be implemented based on Claude's response format
return response
def _format_stream_chunk(self, chunk):
"""Format Claude stream chunk to our expected format"""
# This would need to be implemented based on Claude's stream format
return chunk
class AgentBridge:
"""
Bridge class that integrates super Agent with COW
Manages multiple agent instances per session for conversation isolation
"""
def __init__(self, bridge: Bridge):
self.bridge = bridge
self.agents = {} # session_id -> Agent instance mapping
self.default_agent = None # For backward compatibility (no session_id)
self.agent: Optional[Agent] = None
self.scheduler_initialized = False
# Create helper instances
self.initializer = AgentInitializer(bridge, self)
def create_agent(self, system_prompt: str, tools: List = None, **kwargs) -> Agent:
"""
Create the super agent with COW integration
Args:
system_prompt: System prompt
tools: List of tools (optional)
**kwargs: Additional agent parameters
Returns:
Agent instance
"""
# Create LLM model that uses COW's bot infrastructure
model = AgentLLMModel(self.bridge)
# Default tools if none provided
if tools is None:
# Use ToolManager to load all available tools
from agent.tools import ToolManager
tool_manager = ToolManager()
tool_manager.load_tools()
tools = []
workspace_dir = kwargs.get("workspace_dir")
for tool_name in tool_manager.tool_classes.keys():
try:
tool = tool_manager.create_tool(tool_name)
if tool:
if workspace_dir and hasattr(tool, 'cwd'):
tool.cwd = workspace_dir
tools.append(tool)
except Exception as e:
logger.warning(f"[AgentBridge] Failed to load tool {tool_name}: {e}")
# Create agent instance
agent = Agent(
system_prompt=system_prompt,
description=kwargs.get("description", "AI Super Agent"),
model=model,
tools=tools,
max_steps=kwargs.get("max_steps", 15),
output_mode=kwargs.get("output_mode", "logger"),
workspace_dir=kwargs.get("workspace_dir"),
skill_manager=kwargs.get("skill_manager"),
enable_skills=kwargs.get("enable_skills", True),
memory_manager=kwargs.get("memory_manager"),
max_context_tokens=kwargs.get("max_context_tokens"),
context_reserve_tokens=kwargs.get("context_reserve_tokens"),
runtime_info=kwargs.get("runtime_info"),
)
# Log skill loading details
if agent.skill_manager:
logger.debug(f"[AgentBridge] SkillManager initialized with {len(agent.skill_manager.skills)} skills")
return agent
def get_agent(self, session_id: str = None) -> Optional[Agent]:
"""
Get agent instance for the given session
Args:
session_id: Session identifier (e.g., user_id). If None, returns default agent.
Returns:
Agent instance for this session
"""
# If no session_id, use default agent (backward compatibility)
if session_id is None:
if self.default_agent is None:
self._init_default_agent()
return self.default_agent
# Check if agent exists for this session
if session_id not in self.agents:
self._init_agent_for_session(session_id)
return self.agents[session_id]
def _init_default_agent(self):
"""Initialize default super agent"""
agent = self.initializer.initialize_agent(session_id=None)
self.default_agent = agent
def _init_agent_for_session(self, session_id: str):
"""Initialize agent for a specific session"""
agent = self.initializer.initialize_agent(session_id=session_id)
self.agents[session_id] = agent
def agent_reply(self, query: str, context: Context = None,
on_event=None, clear_history: bool = False) -> Reply:
"""
Use super agent to reply to a query
Args:
query: User query
context: COW context (optional, contains session_id for user isolation)
on_event: Event callback (optional)
clear_history: Whether to clear conversation history
Returns:
Reply object
"""
session_id = None
agent = None
try:
# Extract session_id from context for user isolation
if context:
session_id = context.kwargs.get("session_id") or context.get("session_id")
# Get agent for this session (will auto-initialize if needed)
agent = self.get_agent(session_id=session_id)
if not agent:
return Reply(ReplyType.ERROR, "Failed to initialize super agent")
# Create event handler for logging and channel communication
event_handler = AgentEventHandler(context=context, original_callback=on_event)
# Filter tools based on context
original_tools = agent.tools
filtered_tools = original_tools
# If this is a scheduled task execution, exclude scheduler tool to prevent recursion
if context and context.get("is_scheduled_task"):
filtered_tools = [tool for tool in agent.tools if tool.name != "scheduler"]
agent.tools = filtered_tools
logger.info(f"[AgentBridge] Scheduled task execution: excluded scheduler tool ({len(filtered_tools)}/{len(original_tools)} tools)")
else:
# Attach context to scheduler tool if present
if context and agent.tools:
for tool in agent.tools:
if tool.name == "scheduler":
try:
from agent.tools.scheduler.integration import attach_scheduler_to_tool
attach_scheduler_to_tool(tool, context)
except Exception as e:
logger.warning(f"[AgentBridge] Failed to attach context to scheduler: {e}")
break
# Pass context metadata to model for downstream API requests
if context and hasattr(agent, 'model'):
agent.model.channel_type = context.get("channel_type", "")
agent.model.session_id = session_id or ""
# Store session_id on agent so executor can clear DB on fatal errors
agent._current_session_id = session_id
try:
# Use agent's run_stream method with event handler
response = agent.run_stream(
user_message=query,
on_event=event_handler.handle_event,
clear_history=clear_history
)
finally:
# Restore original tools
if context and context.get("is_scheduled_task"):
agent.tools = original_tools
# Log execution summary
event_handler.log_summary()
# Persist new messages generated during this run
if session_id:
channel_type = (context.get("channel_type") or "") if context else ""
new_messages = getattr(agent, '_last_run_new_messages', [])
if new_messages:
self._persist_messages(session_id, list(new_messages), channel_type)
else:
with agent.messages_lock:
msg_count = len(agent.messages)
if msg_count == 0:
try:
from agent.memory import get_conversation_store
get_conversation_store().clear_session(session_id)
logger.info(f"[AgentBridge] Cleared DB for recovered session: {session_id}")
except Exception as e:
logger.warning(f"[AgentBridge] Failed to clear DB after recovery: {e}")
# Check if there are files to send (from read tool)
if hasattr(agent, 'stream_executor') and hasattr(agent.stream_executor, 'files_to_send'):
files_to_send = agent.stream_executor.files_to_send
if files_to_send:
# Send the first file (for now, handle one file at a time)
file_info = files_to_send[0]
logger.info(f"[AgentBridge] Sending file: {file_info.get('path')}")
# Clear files_to_send for next request
agent.stream_executor.files_to_send = []
# Return file reply based on file type
return self._create_file_reply(file_info, response, context)
return Reply(ReplyType.TEXT, response)
except Exception as e:
logger.error(f"Agent reply error: {e}")
# If the agent cleared its messages due to format error / overflow,
# also purge the DB so the next request starts clean.
if session_id and agent:
try:
with agent.messages_lock:
msg_count = len(agent.messages)
if msg_count == 0:
from agent.memory import get_conversation_store
get_conversation_store().clear_session(session_id)
logger.info(f"[AgentBridge] Cleared DB for session after error: {session_id}")
except Exception as db_err:
logger.warning(f"[AgentBridge] Failed to clear DB after error: {db_err}")
return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")
def _create_file_reply(self, file_info: dict, text_response: str, context: Context = None) -> Reply:
"""
Create a reply for sending files
Args:
file_info: File metadata from read tool
text_response: Text response from agent
context: Context object
Returns:
Reply object for file sending
"""
file_type = file_info.get("file_type", "file")
file_path = file_info.get("path")
# For images, use IMAGE_URL type (channel will handle upload)
if file_type == "image":
# Convert local path to file:// URL for channel processing
file_url = f"file://{file_path}"
logger.info(f"[AgentBridge] Sending image: {file_url}")
reply = Reply(ReplyType.IMAGE_URL, file_url)
# Attach text message if present (for channels that support text+image)
if text_response:
reply.text_content = text_response # Store accompanying text
return reply
# For all file types (document, video, audio), use FILE type
if file_type in ["document", "video", "audio"]:
file_url = f"file://{file_path}"
logger.info(f"[AgentBridge] Sending {file_type}: {file_url}")
reply = Reply(ReplyType.FILE, file_url)
reply.file_name = file_info.get("file_name", os.path.basename(file_path))
# Attach text message if present
if text_response:
reply.text_content = text_response
return reply
# For other unknown file types, return text with file info
message = text_response or file_info.get("message", "文件已准备")
message += f"\n\n[文件: {file_info.get('file_name', file_path)}]"
return Reply(ReplyType.TEXT, message)
def _migrate_config_to_env(self, workspace_root: str):
"""
Migrate API keys from config.json to .env file if not already set
Args:
workspace_root: Workspace directory path (not used, kept for compatibility)
"""
from config import conf
import os
# Mapping from config.json keys to environment variable names
key_mapping = {
"open_ai_api_key": "OPENAI_API_KEY",
"open_ai_api_base": "OPENAI_API_BASE",
"gemini_api_key": "GEMINI_API_KEY",
"claude_api_key": "CLAUDE_API_KEY",
"linkai_api_key": "LINKAI_API_KEY",
}
# Use fixed secure location for .env file
env_file = expand_path("~/.cow/.env")
# Read existing env vars from .env file
existing_env_vars = {}
if os.path.exists(env_file):
try:
with open(env_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line and not line.startswith('#') and '=' in line:
key, _ = line.split('=', 1)
existing_env_vars[key.strip()] = True
except Exception as e:
logger.warning(f"[AgentBridge] Failed to read .env file: {e}")
# Check which keys need to be migrated
keys_to_migrate = {}
for config_key, env_key in key_mapping.items():
# Skip if already in .env file
if env_key in existing_env_vars:
continue
# Get value from config.json
value = conf().get(config_key, "")
if value and value.strip(): # Only migrate non-empty values
keys_to_migrate[env_key] = value.strip()
# Log summary if there are keys to skip
if existing_env_vars:
logger.debug(f"[AgentBridge] {len(existing_env_vars)} env vars already in .env")
# Write new keys to .env file
if keys_to_migrate:
try:
# Ensure ~/.cow directory and .env file exist
env_dir = os.path.dirname(env_file)
if not os.path.exists(env_dir):
os.makedirs(env_dir, exist_ok=True)
if not os.path.exists(env_file):
open(env_file, 'a').close()
# Append new keys
with open(env_file, 'a', encoding='utf-8') as f:
f.write('\n# Auto-migrated from config.json\n')
for key, value in keys_to_migrate.items():
f.write(f'{key}={value}\n')
# Also set in current process
os.environ[key] = value
logger.info(f"[AgentBridge] Migrated {len(keys_to_migrate)} API keys from config.json to .env: {list(keys_to_migrate.keys())}")
except Exception as e:
logger.warning(f"[AgentBridge] Failed to migrate API keys: {e}")
def _persist_messages(
self, session_id: str, new_messages: list, channel_type: str = ""
) -> None:
"""
Persist new messages to the conversation store after each agent run.
Failures are logged but never propagate — they must not interrupt replies.
"""
if not new_messages:
return
try:
from config import conf
if not conf().get("conversation_persistence", True):
return
except Exception:
pass
try:
from agent.memory import get_conversation_store
get_conversation_store().append_messages(
session_id, new_messages, channel_type=channel_type
)
except Exception as e:
logger.warning(
f"[AgentBridge] Failed to persist messages for session={session_id}: {e}"
)
def clear_session(self, session_id: str):
"""
Clear a specific session's agent and conversation history
Args:
session_id: Session identifier to clear
"""
if session_id in self.agents:
logger.info(f"[AgentBridge] Clearing session: {session_id}")
del self.agents[session_id]
def clear_all_sessions(self):
"""Clear all agent sessions"""
logger.info(f"[AgentBridge] Clearing all sessions ({len(self.agents)} total)")
self.agents.clear()
self.default_agent = None
def refresh_all_skills(self) -> int:
"""
Refresh skills and conditional tools in all agent instances after
environment variable changes. This allows hot-reload without restarting.
Returns:
Number of agent instances refreshed
"""
import os
from dotenv import load_dotenv
from config import conf
# Reload environment variables from .env file
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
env_file = os.path.join(workspace_root, '.env')
if os.path.exists(env_file):
load_dotenv(env_file, override=True)
logger.info(f"[AgentBridge] Reloaded environment variables from {env_file}")
refreshed_count = 0
# Collect all agent instances to refresh
agents_to_refresh = []
if self.default_agent:
agents_to_refresh.append(("default", self.default_agent))
for session_id, agent in self.agents.items():
agents_to_refresh.append((session_id, agent))
for label, agent in agents_to_refresh:
# Refresh skills
if hasattr(agent, 'skill_manager') and agent.skill_manager:
agent.skill_manager.refresh_skills()
# Refresh conditional tools (e.g. web_search depends on API keys)
self._refresh_conditional_tools(agent)
refreshed_count += 1
if refreshed_count > 0:
logger.info(f"[AgentBridge] Refreshed skills & tools in {refreshed_count} agent instance(s)")
return refreshed_count
@staticmethod
def _refresh_conditional_tools(agent):
"""
Add or remove conditional tools based on current environment variables.
For example, web_search should only be present when BOCHA_API_KEY or
LINKAI_API_KEY is set.
"""
try:
from agent.tools.web_search.web_search import WebSearch
has_tool = any(t.name == "web_search" for t in agent.tools)
available = WebSearch.is_available()
if available and not has_tool:
# API key was added - inject the tool
tool = WebSearch()
tool.model = agent.model
agent.tools.append(tool)
logger.info("[AgentBridge] web_search tool added (API key now available)")
elif not available and has_tool:
# API key was removed - remove the tool
agent.tools = [t for t in agent.tools if t.name != "web_search"]
logger.info("[AgentBridge] web_search tool removed (API key no longer available)")
except Exception as e:
logger.debug(f"[AgentBridge] Failed to refresh conditional tools: {e}")

View File

@@ -0,0 +1,115 @@
"""
Agent Event Handler - Handles agent events and thinking process output
"""
from common.log import logger
class AgentEventHandler:
"""
Handles agent events and optionally sends intermediate messages to channel
"""
def __init__(self, context=None, original_callback=None):
"""
Initialize event handler
Args:
context: COW context (for accessing channel)
original_callback: Original event callback to chain
"""
self.context = context
self.original_callback = original_callback
# Get channel for sending intermediate messages
self.channel = None
if context:
self.channel = context.kwargs.get("channel") if hasattr(context, "kwargs") else None
# Track current thinking for channel output
self.current_thinking = ""
self.turn_number = 0
def handle_event(self, event):
"""
Main event handler
Args:
event: Event dict with type and data
"""
event_type = event.get("type")
data = event.get("data", {})
# Dispatch to specific handlers
if event_type == "turn_start":
self._handle_turn_start(data)
elif event_type == "message_update":
self._handle_message_update(data)
elif event_type == "message_end":
self._handle_message_end(data)
elif event_type == "tool_execution_start":
self._handle_tool_execution_start(data)
elif event_type == "tool_execution_end":
self._handle_tool_execution_end(data)
# Call original callback if provided
if self.original_callback:
self.original_callback(event)
def _handle_turn_start(self, data):
"""Handle turn start event"""
self.turn_number = data.get("turn", 0)
self.has_tool_calls_in_turn = False
self.current_thinking = ""
def _handle_message_update(self, data):
"""Handle message update event (streaming text)"""
delta = data.get("delta", "")
self.current_thinking += delta
def _handle_message_end(self, data):
"""Handle message end event"""
tool_calls = data.get("tool_calls", [])
# Only send thinking process if followed by tool calls
if tool_calls:
if self.current_thinking.strip():
logger.info(f"💭 {self.current_thinking.strip()[:200]}{'...' if len(self.current_thinking) > 200 else ''}")
# Send thinking process to channel
self._send_to_channel(f"{self.current_thinking.strip()}")
else:
# No tool calls = final response (logged at agent_stream level)
if self.current_thinking.strip():
logger.debug(f"💬 {self.current_thinking.strip()[:200]}{'...' if len(self.current_thinking) > 200 else ''}")
self.current_thinking = ""
def _handle_tool_execution_start(self, data):
"""Handle tool execution start event - logged by agent_stream.py"""
pass
def _handle_tool_execution_end(self, data):
"""Handle tool execution end event - logged by agent_stream.py"""
pass
def _send_to_channel(self, message):
"""
Try to send intermediate message to channel.
Skipped in SSE mode because thinking text is already streamed via on_event.
"""
if self.context and self.context.get("on_event"):
return
if self.channel:
try:
from bridge.reply import Reply, ReplyType
reply = Reply(ReplyType.TEXT, message)
self.channel._send(reply, self.context)
except Exception as e:
logger.debug(f"[AgentEventHandler] Failed to send to channel: {e}")
def log_summary(self):
"""Log execution summary - simplified"""
# Summary removed as per user request
# Real-time logging during execution is sufficient
pass

584
bridge/agent_initializer.py Normal file
View File

@@ -0,0 +1,584 @@
"""
Agent Initializer - Handles agent initialization logic
"""
import os
import asyncio
import datetime
import time
from typing import Optional, List
from agent.protocol import Agent
from agent.tools import ToolManager
from common.log import logger
from common.utils import expand_path
class AgentInitializer:
"""
Handles agent initialization including:
- Workspace setup
- Memory system initialization
- Tool loading
- System prompt building
"""
def __init__(self, bridge, agent_bridge):
"""
Initialize agent initializer
Args:
bridge: COW bridge instance
agent_bridge: AgentBridge instance (for create_agent method)
"""
self.bridge = bridge
self.agent_bridge = agent_bridge
def initialize_agent(self, session_id: Optional[str] = None) -> Agent:
"""
Initialize agent for a session
Args:
session_id: Session ID (None for default agent)
Returns:
Initialized agent instance
"""
from config import conf
# Get workspace from config
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
# Migrate API keys
self._migrate_config_to_env(workspace_root)
# Load environment variables
self._load_env_file()
# Initialize workspace
from agent.prompt import ensure_workspace, load_context_files, PromptBuilder
workspace_files = ensure_workspace(workspace_root, create_templates=True)
if session_id is None:
logger.info(f"[AgentInitializer] Workspace initialized at: {workspace_root}")
# Setup memory system
memory_manager, memory_tools = self._setup_memory_system(workspace_root, session_id)
# Load tools
tools = self._load_tools(workspace_root, memory_manager, memory_tools, session_id)
# Initialize scheduler if needed
self._initialize_scheduler(tools, session_id)
# Load context files
context_files = load_context_files(workspace_root)
# Initialize skill manager
skill_manager = self._initialize_skill_manager(workspace_root, session_id)
# Build system prompt
prompt_builder = PromptBuilder(workspace_dir=workspace_root, language="zh")
runtime_info = self._get_runtime_info(workspace_root)
system_prompt = prompt_builder.build(
tools=tools,
context_files=context_files,
skill_manager=skill_manager,
memory_manager=memory_manager,
runtime_info=runtime_info,
)
# Get cost control parameters
from config import conf
max_steps = conf().get("agent_max_steps", 20)
max_context_tokens = conf().get("agent_max_context_tokens", 50000)
# Create agent
agent = self.agent_bridge.create_agent(
system_prompt=system_prompt,
tools=tools,
max_steps=max_steps,
output_mode="logger",
workspace_dir=workspace_root,
skill_manager=skill_manager,
enable_skills=True,
max_context_tokens=max_context_tokens,
runtime_info=runtime_info # Pass runtime_info for dynamic time updates
)
# Attach memory manager and share LLM model for summarization
if memory_manager:
agent.memory_manager = memory_manager
if hasattr(agent, 'model') and agent.model:
memory_manager.flush_manager.llm_model = agent.model
# Restore persisted conversation history for this session
if session_id:
self._restore_conversation_history(agent, session_id)
# Start daily memory flush timer (once, on first agent init regardless of session)
self._start_daily_flush_timer()
return agent
def _restore_conversation_history(self, agent, session_id: str) -> None:
"""
Load persisted conversation messages from SQLite and inject them
into the agent's in-memory message list.
Only user text and assistant text are restored. Tool call chains
(tool_use / tool_result) are stripped out because:
1. They are intermediate process, the value is already in the final
assistant text reply.
2. They consume massive context tokens (often 80%+ of history).
3. Different models have incompatible tool message formats, so
restoring tool chains across model switches causes 400 errors.
4. Eliminates the entire class of tool_use/tool_result pairing bugs.
"""
from config import conf
if not conf().get("conversation_persistence", True):
return
try:
from agent.memory import get_conversation_store
store = get_conversation_store()
max_turns = conf().get("agent_max_context_turns", 20)
restore_turns = max(3, max_turns // 6)
saved = store.load_messages(session_id, max_turns=restore_turns)
if saved:
filtered = self._filter_text_only_messages(saved)
if filtered:
with agent.messages_lock:
agent.messages = filtered
logger.debug(
f"[AgentInitializer] Restored {len(filtered)} text messages "
f"(from {len(saved)} total, {restore_turns} turns cap) "
f"for session={session_id}"
)
except Exception as e:
logger.warning(
f"[AgentInitializer] Failed to restore conversation history for "
f"session={session_id}: {e}"
)
@staticmethod
def _filter_text_only_messages(messages: list) -> list:
"""
Extract clean user/assistant turn pairs from raw message history.
Groups messages into turns (each starting with a real user query),
then keeps only:
- The first user text in each turn (the actual user input)
- The last assistant text in each turn (the final answer)
All tool_use, tool_result, intermediate assistant thoughts, and
internal hint messages injected by the agent loop are discarded.
"""
def _extract_text(content) -> str:
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
parts = [
b.get("text", "")
for b in content
if isinstance(b, dict) and b.get("type") == "text"
]
return "\n".join(p for p in parts if p).strip()
return ""
def _is_real_user_msg(msg: dict) -> bool:
"""True for actual user input, False for tool_result or internal hints."""
if msg.get("role") != "user":
return False
content = msg.get("content")
if isinstance(content, list):
has_tool_result = any(
isinstance(b, dict) and b.get("type") == "tool_result"
for b in content
)
if has_tool_result:
return False
text = _extract_text(content)
return bool(text)
# Group into turns: each turn starts with a real user message
turns = []
current_turn = None
for msg in messages:
if _is_real_user_msg(msg):
if current_turn is not None:
turns.append(current_turn)
current_turn = {"user": msg, "assistants": []}
elif current_turn is not None and msg.get("role") == "assistant":
text = _extract_text(msg.get("content"))
if text:
current_turn["assistants"].append(text)
if current_turn is not None:
turns.append(current_turn)
# Build result: one user msg + one assistant msg per turn
filtered = []
for turn in turns:
user_text = _extract_text(turn["user"].get("content"))
if not user_text:
continue
filtered.append({
"role": "user",
"content": [{"type": "text", "text": user_text}]
})
if turn["assistants"]:
final_reply = turn["assistants"][-1]
filtered.append({
"role": "assistant",
"content": [{"type": "text", "text": final_reply}]
})
return filtered
def _load_env_file(self):
"""Load environment variables from .env file"""
env_file = expand_path("~/.cow/.env")
if os.path.exists(env_file):
try:
from dotenv import load_dotenv
load_dotenv(env_file, override=True)
except ImportError:
logger.warning("[AgentInitializer] python-dotenv not installed")
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to load .env file: {e}")
def _setup_memory_system(self, workspace_root: str, session_id: Optional[str] = None):
"""
Setup memory system
Returns:
(memory_manager, memory_tools) tuple
"""
memory_manager = None
memory_tools = []
try:
from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider
from agent.tools import MemorySearchTool, MemoryGetTool
from config import conf
# Initialize embedding provider (prefer OpenAI, fallback to LinkAI)
embedding_provider = None
openai_api_key = conf().get("open_ai_api_key", "")
openai_api_base = conf().get("open_ai_api_base", "")
if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
try:
embedding_provider = create_embedding_provider(
provider="openai",
model="text-embedding-3-small",
api_key=openai_api_key,
api_base=openai_api_base or "https://api.openai.com/v1"
)
if session_id is None:
logger.info("[AgentInitializer] OpenAI embedding initialized")
except Exception as e:
logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}")
if embedding_provider is None:
linkai_api_key = conf().get("linkai_api_key", "") or os.environ.get("LINKAI_API_KEY", "")
linkai_api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
if linkai_api_key and linkai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
try:
embedding_provider = create_embedding_provider(
provider="linkai",
model="text-embedding-3-small",
api_key=linkai_api_key,
api_base=f"{linkai_api_base}/v1"
)
if session_id is None:
logger.info("[AgentInitializer] LinkAI embedding initialized (fallback)")
except Exception as e:
logger.warning(f"[AgentInitializer] LinkAI embedding failed: {e}")
# Create memory manager
memory_config = MemoryConfig(workspace_root=workspace_root)
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
# Sync memory
self._sync_memory(memory_manager, session_id)
# Create memory tools
memory_tools = [
MemorySearchTool(memory_manager),
MemoryGetTool(memory_manager)
]
if session_id is None:
logger.info("[AgentInitializer] Memory system initialized")
except Exception as e:
logger.warning(f"[AgentInitializer] Memory system not available: {e}")
return memory_manager, memory_tools
def _sync_memory(self, memory_manager, session_id: Optional[str] = None):
"""Sync memory database"""
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
raise RuntimeError("Event loop is closed")
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
if loop.is_running():
asyncio.create_task(memory_manager.sync())
else:
loop.run_until_complete(memory_manager.sync())
except Exception as e:
logger.warning(f"[AgentInitializer] Memory sync failed: {e}")
def _load_tools(self, workspace_root: str, memory_manager, memory_tools: List, session_id: Optional[str] = None):
"""Load all tools"""
tool_manager = ToolManager()
tool_manager.load_tools()
tools = []
file_config = {
"cwd": workspace_root,
"memory_manager": memory_manager
} if memory_manager else {"cwd": workspace_root}
for tool_name in tool_manager.tool_classes.keys():
try:
# Skip web_search if no API key is available
if tool_name == "web_search":
from agent.tools.web_search.web_search import WebSearch
if not WebSearch.is_available():
logger.debug("[AgentInitializer] WebSearch skipped - no BOCHA_API_KEY or LINKAI_API_KEY")
continue
# Special handling for EnvConfig tool
if tool_name == "env_config":
from agent.tools import EnvConfig
tool = EnvConfig({"agent_bridge": self.agent_bridge})
else:
tool = tool_manager.create_tool(tool_name)
if tool:
# Apply workspace config to file operation tools
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls', 'web_fetch', 'send', 'browser']:
tool.config = file_config
tool.cwd = file_config.get("cwd", getattr(tool, 'cwd', None))
if 'memory_manager' in file_config:
tool.memory_manager = file_config['memory_manager']
tools.append(tool)
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to load tool {tool_name}: {e}")
# Add memory tools
if memory_tools:
tools.extend(memory_tools)
if session_id is None:
logger.info(f"[AgentInitializer] Added {len(memory_tools)} memory tools")
if session_id is None:
logger.info(f"[AgentInitializer] Loaded {len(tools)} tools: {[t.name for t in tools]}")
return tools
def _initialize_scheduler(self, tools: List, session_id: Optional[str] = None):
"""Initialize scheduler service if needed"""
if not self.agent_bridge.scheduler_initialized:
try:
from agent.tools.scheduler.integration import init_scheduler
if init_scheduler(self.agent_bridge):
self.agent_bridge.scheduler_initialized = True
if session_id is None:
logger.info("[AgentInitializer] Scheduler service initialized")
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to initialize scheduler: {e}")
# Inject scheduler dependencies
if self.agent_bridge.scheduler_initialized:
try:
from agent.tools.scheduler.integration import get_task_store, get_scheduler_service
from agent.tools import SchedulerTool
from config import conf
task_store = get_task_store()
scheduler_service = get_scheduler_service()
for tool in tools:
if isinstance(tool, SchedulerTool):
tool.task_store = task_store
tool.scheduler_service = scheduler_service
if not tool.config:
tool.config = {}
raw_ct = conf().get("channel_type", "unknown")
if isinstance(raw_ct, list):
ct = raw_ct[0] if raw_ct else "unknown"
elif isinstance(raw_ct, str) and "," in raw_ct:
ct = raw_ct.split(",")[0].strip()
else:
ct = raw_ct
tool.config["channel_type"] = ct
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to inject scheduler dependencies: {e}")
def _initialize_skill_manager(self, workspace_root: str, session_id: Optional[str] = None):
"""Initialize skill manager"""
try:
from agent.skills import SkillManager
skill_manager = SkillManager(custom_dir=os.path.join(workspace_root, "skills"))
return skill_manager
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to initialize SkillManager: {e}")
return None
def _get_runtime_info(self, workspace_root: str):
"""Get runtime information with dynamic time support"""
from config import conf
def get_current_time():
"""Get current time dynamically - called each time system prompt is accessed"""
now = datetime.datetime.now()
# Get timezone info
try:
offset = -time.timezone if not time.daylight else -time.altzone
hours = offset // 3600
minutes = (offset % 3600) // 60
timezone_name = f"UTC{hours:+03d}:{minutes:02d}" if minutes else f"UTC{hours:+03d}"
except Exception:
timezone_name = "UTC"
# Chinese weekday mapping
weekday_map = {
'Monday': '星期一', 'Tuesday': '星期二', 'Wednesday': '星期三',
'Thursday': '星期四', 'Friday': '星期五', 'Saturday': '星期六', 'Sunday': '星期日'
}
weekday_zh = weekday_map.get(now.strftime("%A"), now.strftime("%A"))
return {
'time': now.strftime("%Y-%m-%d %H:%M:%S"),
'weekday': weekday_zh,
'timezone': timezone_name
}
return {
"model": conf().get("model", "unknown"),
"workspace": workspace_root,
"channel": ", ".join(conf().get("channel_type")) if isinstance(conf().get("channel_type"), list) else conf().get("channel_type", "unknown"),
"_get_current_time": get_current_time # Dynamic time function
}
def _migrate_config_to_env(self, workspace_root: str):
"""Migrate API keys from config.json to .env file"""
from config import conf
key_mapping = {
"open_ai_api_key": "OPENAI_API_KEY",
"open_ai_api_base": "OPENAI_API_BASE",
"gemini_api_key": "GEMINI_API_KEY",
"claude_api_key": "CLAUDE_API_KEY",
"linkai_api_key": "LINKAI_API_KEY",
}
env_file = expand_path("~/.cow/.env")
# Read existing env vars
existing_env_vars = {}
if os.path.exists(env_file):
try:
with open(env_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line and not line.startswith('#') and '=' in line:
key, _ = line.split('=', 1)
existing_env_vars[key.strip()] = True
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to read .env file: {e}")
# Check which keys need migration
keys_to_migrate = {}
for config_key, env_key in key_mapping.items():
if env_key in existing_env_vars:
continue
value = conf().get(config_key, "")
if value and value.strip():
keys_to_migrate[env_key] = value.strip()
# Write new keys
if keys_to_migrate:
try:
env_dir = os.path.dirname(env_file)
if not os.path.exists(env_dir):
os.makedirs(env_dir, exist_ok=True)
if not os.path.exists(env_file):
open(env_file, 'a').close()
with open(env_file, 'a', encoding='utf-8') as f:
f.write('\n# Auto-migrated from config.json\n')
for key, value in keys_to_migrate.items():
f.write(f'{key}={value}\n')
os.environ[key] = value
logger.info(f"[AgentInitializer] Migrated {len(keys_to_migrate)} API keys to .env: {list(keys_to_migrate.keys())}")
except Exception as e:
logger.warning(f"[AgentInitializer] Failed to migrate API keys: {e}")
def _start_daily_flush_timer(self):
"""Start a background thread that flushes all agents' memory daily at 23:55."""
if getattr(self.agent_bridge, '_daily_flush_started', False):
return
self.agent_bridge._daily_flush_started = True
import threading
def _daily_flush_loop():
while True:
try:
now = datetime.datetime.now()
target = now.replace(hour=23, minute=55, second=0, microsecond=0)
if target <= now:
target += datetime.timedelta(days=1)
wait_seconds = (target - now).total_seconds()
logger.info(f"[DailyFlush] Next flush at {target.strftime('%Y-%m-%d %H:%M')} (in {wait_seconds/3600:.1f}h)")
time.sleep(wait_seconds)
self._flush_all_agents()
except Exception as e:
logger.warning(f"[DailyFlush] Error in daily flush loop: {e}")
time.sleep(3600)
t = threading.Thread(target=_daily_flush_loop, daemon=True)
t.start()
def _flush_all_agents(self):
"""Flush memory for all active agent sessions."""
agents = []
if self.agent_bridge.default_agent:
agents.append(("default", self.agent_bridge.default_agent))
for sid, agent in self.agent_bridge.agents.items():
agents.append((sid, agent))
if not agents:
return
flushed = 0
for label, agent in agents:
try:
if not agent.memory_manager:
continue
with agent.messages_lock:
messages = list(agent.messages)
if not messages:
continue
result = agent.memory_manager.flush_manager.create_daily_summary(messages)
if result:
flushed += 1
except Exception as e:
logger.warning(f"[DailyFlush] Failed for session {label}: {e}")
if flushed:
logger.info(f"[DailyFlush] Flushed {flushed}/{len(agents)} agent session(s)")

View File

@@ -1,4 +1,4 @@
from bot.bot_factory import create_bot
from models.bot_factory import create_bot
from bridge.context import Context
from bridge.reply import Reply
from common import const
@@ -13,31 +13,76 @@ from voice.factory import create_voice
class Bridge(object):
def __init__(self):
self.btype = {
"chat": const.CHATGPT,
"chat": const.OPENAI,
"voice_to_text": conf().get("voice_to_text", "openai"),
"text_to_voice": conf().get("text_to_voice", "google"),
"translate": conf().get("translate", "baidu"),
}
model_type = conf().get("model") or const.GPT35
if model_type in ["text-davinci-003"]:
self.btype["chat"] = const.OPEN_AI
if conf().get("use_azure_chatgpt", False):
self.btype["chat"] = const.CHATGPTONAZURE
if model_type in ["wenxin", "wenxin-4"]:
self.btype["chat"] = const.BAIDU
if model_type in ["xunfei"]:
self.btype["chat"] = const.XUNFEI
if conf().get("use_linkai") and conf().get("linkai_api_key"):
self.btype["chat"] = const.LINKAI
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
self.btype["voice_to_text"] = const.LINKAI
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
self.btype["text_to_voice"] = const.LINKAI
if model_type in ["claude"]:
self.btype["chat"] = const.CLAUDEAI
# 这边取配置的模型
bot_type = conf().get("bot_type")
if bot_type:
self.btype["chat"] = bot_type
else:
model_type = conf().get("model") or const.GPT_41_MINI
# Ensure model_type is string to prevent AttributeError when using startswith()
# This handles cases where numeric model names (e.g., "1") are parsed as integers from YAML
if not isinstance(model_type, str):
logger.warning(f"[Bridge] model_type is not a string: {model_type} (type: {type(model_type).__name__}), converting to string")
model_type = str(model_type)
if model_type in ["text-davinci-003"]:
self.btype["chat"] = const.OPEN_AI
if conf().get("use_azure_chatgpt", False):
self.btype["chat"] = const.CHATGPTONAZURE
if model_type in ["wenxin", "wenxin-4"]:
self.btype["chat"] = const.BAIDU
if model_type in ["xunfei"]:
self.btype["chat"] = const.XUNFEI
if model_type in [const.QWEN]:
self.btype["chat"] = const.QWEN
if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
self.btype["chat"] = const.QWEN_DASHSCOPE
# Support Qwen3 and other DashScope models
if model_type and (model_type.startswith("qwen") or model_type.startswith("qwq") or model_type.startswith("qvq")):
self.btype["chat"] = const.QWEN_DASHSCOPE
if model_type and model_type.startswith("gemini"):
self.btype["chat"] = const.GEMINI
if model_type and model_type.startswith("glm"):
self.btype["chat"] = const.ZHIPU_AI
if model_type and model_type.startswith("claude"):
self.btype["chat"] = const.CLAUDEAPI
if model_type in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
self.btype["chat"] = const.MOONSHOT
if model_type and model_type.startswith("kimi"):
self.btype["chat"] = const.MOONSHOT
if model_type and model_type.startswith("doubao"):
self.btype["chat"] = const.DOUBAO
if model_type and model_type.startswith("deepseek"):
self.btype["chat"] = const.DEEPSEEK
if model_type in [const.MODELSCOPE]:
self.btype["chat"] = const.MODELSCOPE
# MiniMax models
if model_type and (model_type in ["abab6.5-chat", "abab6.5"] or model_type.lower().startswith("minimax")):
self.btype["chat"] = const.MiniMax
if conf().get("use_linkai") and conf().get("linkai_api_key"):
self.btype["chat"] = const.LINKAI
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
self.btype["voice_to_text"] = const.LINKAI
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
self.btype["text_to_voice"] = const.LINKAI
self.bots = {}
self.chat_bots = {}
self._agent_bridge = None
# 模型对应的接口
def get_bot(self, typename):
if self.bots.get(typename) is None:
logger.info("create bot {} for {}".format(self.btype[typename], typename))
@@ -76,3 +121,29 @@ class Bridge(object):
重置bot路由
"""
self.__init__()
def get_agent_bridge(self):
"""
Get agent bridge for agent-based conversations
"""
if self._agent_bridge is None:
from bridge.agent_bridge import AgentBridge
self._agent_bridge = AgentBridge(self)
return self._agent_bridge
def fetch_agent_reply(self, query: str, context: Context = None,
on_event=None, clear_history: bool = False) -> Reply:
"""
Use super agent to handle the query
Args:
query: User query
context: Context object
on_event: Event callback for streaming
clear_history: Whether to clear conversation history
Returns:
Reply object
"""
agent_bridge = self.get_agent_bridge()
return agent_bridge.agent_reply(query, context, on_event, clear_history)

View File

@@ -16,6 +16,8 @@ class ContextType(Enum):
JOIN_GROUP = 20 # 加入群聊
PATPAT = 21 # 拍了拍
FUNCTION = 22 # 函数调用
EXIT_GROUP = 23 #退出
def __str__(self):
return self.name

View File

@@ -11,7 +11,7 @@ class ReplyType(Enum):
VIDEO_URL = 5 # 视频URL
FILE = 6 # 文件
CARD = 7 # 微信名片仅支持ntchat
InviteRoom = 8 # 邀请好友进群
INVITE_ROOM = 8 # 邀请好友进群
INFO = 9
ERROR = 10
TEXT_ = 11 # 强制文本

View File

@@ -5,17 +5,52 @@ Message sending channel abstract class
from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import *
from common.log import logger
from config import conf
class Channel(object):
channel_type = ""
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
def __init__(self):
import threading
self._startup_event = threading.Event()
self._startup_error = None
self.cloud_mode = False # set to True by ChannelManager when running with cloud client
def startup(self):
"""
init channel
"""
raise NotImplementedError
def report_startup_success(self):
self._startup_error = None
self._startup_event.set()
def report_startup_error(self, error: str):
self._startup_error = error
self._startup_event.set()
def wait_startup(self, timeout: float = 3) -> (bool, str):
"""
Wait for channel startup result.
Returns (success: bool, error_msg: str).
"""
ready = self._startup_event.wait(timeout=timeout)
if not ready:
return True, ""
if self._startup_error:
return False, self._startup_error
return True, ""
def stop(self):
"""
stop channel gracefully, called before restart
"""
pass
def handle_text(self, msg):
"""
process received msg
@@ -34,7 +69,37 @@ class Channel(object):
raise NotImplementedError
def build_reply_content(self, query, context: Context = None) -> Reply:
return Bridge().fetch_reply_content(query, context)
"""
Build reply content, using agent if enabled in config
"""
# Check if agent mode is enabled
use_agent = conf().get("agent", False)
if use_agent:
try:
logger.info("[Channel] Using agent mode")
# Add channel_type to context if not present
if context and "channel_type" not in context:
context["channel_type"] = self.channel_type
# Read on_event callback injected by the channel (e.g. web SSE)
on_event = context.get("on_event") if context else None
# Use agent bridge to handle the query
return Bridge().fetch_agent_reply(
query=query,
context=context,
on_event=on_event,
clear_history=False
)
except Exception as e:
logger.error(f"[Channel] Agent mode failed, fallback to normal mode: {e}")
# Fallback to normal mode if agent fails
return Bridge().fetch_reply_content(query, context)
else:
# Normal mode
return Bridge().fetch_reply_content(query, context)
def build_voice_to_text(self, voice_file) -> Reply:
return Bridge().fetch_voice_to_text(voice_file)

View File

@@ -2,43 +2,48 @@
channel factory
"""
from common import const
from .channel import Channel
def create_channel(channel_type):
def create_channel(channel_type) -> Channel:
"""
create a channel instance
:param channel_type: channel type code
:return: channel instance
"""
if channel_type == "wx":
from channel.wechat.wechat_channel import WechatChannel
return WechatChannel()
elif channel_type == "wxy":
from channel.wechat.wechaty_channel import WechatyChannel
return WechatyChannel()
elif channel_type == "terminal":
ch = Channel()
if channel_type == "terminal":
from channel.terminal.terminal_channel import TerminalChannel
return TerminalChannel()
ch = TerminalChannel()
elif channel_type == 'web':
from channel.web.web_channel import WebChannel
ch = WebChannel()
elif channel_type == "wechatmp":
from channel.wechatmp.wechatmp_channel import WechatMPChannel
return WechatMPChannel(passive_reply=True)
ch = WechatMPChannel(passive_reply=True)
elif channel_type == "wechatmp_service":
from channel.wechatmp.wechatmp_channel import WechatMPChannel
return WechatMPChannel(passive_reply=False)
ch = WechatMPChannel(passive_reply=False)
elif channel_type == "wechatcom_app":
from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
return WechatComAppChannel()
elif channel_type == "wework":
from channel.wework.wework_channel import WeworkChannel
return WeworkChannel()
ch = WechatComAppChannel()
elif channel_type == const.FEISHU:
from channel.feishu.feishu_channel import FeiShuChanel
return FeiShuChanel()
raise RuntimeError
ch = FeiShuChanel()
elif channel_type == const.DINGTALK:
from channel.dingtalk.dingtalk_channel import DingTalkChanel
ch = DingTalkChanel()
elif channel_type == const.WECOM_BOT:
from channel.wecom_bot.wecom_bot_channel import WecomBotChannel
ch = WecomBotChannel()
elif channel_type == const.QQ:
from channel.qq.qq_channel import QQChannel
ch = QQChannel()
elif channel_type in (const.WEIXIN, "wx"):
from channel.weixin.weixin_channel import WeixinChannel
ch = WeixinChannel()
channel_type = const.WEIXIN
else:
raise RuntimeError
ch.channel_type = channel_type
return ch

View File

@@ -17,17 +17,24 @@ try:
except Exception as e:
pass
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
# 抽象类, 它包含了与消息通道无关的通用处理逻辑
class ChatChannel(Channel):
name = None # 登录的用户名
user_id = None # 登录的用户id
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉正在执行的不会被取消
sessions = {} # 用于控制并发每个session_id同时只能有一个context在处理
lock = threading.Lock() # 用于控制对sessions的访问
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
def __init__(self):
super().__init__()
# Instance-level attributes so each channel subclass has its own
# independent session queue and lock. Previously these were class-level,
# which caused contexts from one channel (e.g. Feishu) to be consumed
# by another channel's consume() thread (e.g. Web), leading to errors
# like "No request_id found in context".
self.futures = {}
self.sessions = {}
self.lock = threading.Lock()
_thread = threading.Thread(target=self.consume)
_thread.setDaemon(True)
_thread.start()
@@ -36,9 +43,8 @@ class ChatChannel(Channel):
def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content)
context.kwargs = kwargs
# context首次传入时origin_ctype是None,
# 引入的起因是当输入语音时会嵌套生成两个context第一步语音转文本第二步通过文本生成文字回复。
# origin_ctype用于第二步文本回复时判断是否需要匹配前缀如果是私聊的语音就不需要匹配前缀
if "channel_type" not in context:
context["channel_type"] = self.channel_type
if "origin_ctype" not in context:
context["origin_ctype"] = ctype
# context首次传入时receiver是None根据类型设置receiver
@@ -63,16 +69,24 @@ class ChatChannel(Channel):
check_contain(group_name, group_name_keyword_white_list),
]
):
group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
session_id = cmsg.actual_user_id
if any(
[
group_name in group_chat_in_one_session,
"ALL_GROUP" in group_chat_in_one_session,
]
):
# Check global group_shared_session config first
group_shared_session = conf().get("group_shared_session", True)
if group_shared_session:
# All users in the group share the same session
session_id = group_id
else:
# Check group-specific whitelist (legacy behavior)
group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
session_id = cmsg.actual_user_id
if any(
[
group_name in group_chat_in_one_session,
"ALL_GROUP" in group_chat_in_one_session,
]
):
session_id = group_id
else:
logger.debug(f"No need reply, groupName not in whitelist, group_name={group_name}")
return None
context["session_id"] = session_id
context["receiver"] = group_id
@@ -84,14 +98,14 @@ class ChatChannel(Channel):
if e_context.is_pass() or context is None:
return context
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
logger.debug("[WX]self message skipped")
logger.debug("[chat_channel]self message skipped")
return None
# 消息内容匹配过程并处理content
if ctype == ContextType.TEXT:
if first_in and "\n- - - - - - -" in content: # 初次匹配 过滤引用消息
logger.debug(content)
logger.debug("[WX]reference query skipped")
logger.debug("[chat_channel]reference query skipped")
return None
nick_name_black_list = conf().get("nick_name_black_list", [])
@@ -109,12 +123,13 @@ class ChatChannel(Channel):
nick_name = context["msg"].actual_user_nickname
if nick_name and nick_name in nick_name_black_list:
# 黑名单过滤
logger.warning(f"[WX] Nickname {nick_name} in In BlackList, ignore")
logger.warning(f"[chat_channel] Nickname {nick_name} in In BlackList, ignore")
return None
logger.info("[WX]receive group at")
logger.info("[chat_channel]receive group at")
if not conf().get("group_at_off", False):
flag = True
self.name = self.name if self.name is not None else "" # 部分渠道self.name可能没有赋值
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
subtract_res = re.sub(pattern, r"", content)
if isinstance(context["msg"].at_list, list):
@@ -128,13 +143,13 @@ class ChatChannel(Channel):
content = subtract_res
if not flag:
if context["origin_ctype"] == ContextType.VOICE:
logger.info("[WX]receive group voice, but checkprefix didn't match")
logger.info("[chat_channel]receive group voice, but checkprefix didn't match")
return None
else: # 单聊
nick_name = context["msg"].from_user_nickname
if nick_name and nick_name in nick_name_black_list:
# 黑名单过滤
logger.warning(f"[WX] Nickname '{nick_name}' in In BlackList, ignore")
logger.warning(f"[chat_channel] Nickname '{nick_name}' in In BlackList, ignore")
return None
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
@@ -143,9 +158,10 @@ class ChatChannel(Channel):
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
pass
else:
logger.info("[chat_channel]receive single chat msg, but checkprefix didn't match")
return None
content = content.strip()
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
img_match_prefix = check_prefix(content, conf().get("image_create_prefix",[""]))
if img_match_prefix:
content = content.replace(img_match_prefix, "", 1)
context.type = ContextType.IMAGE_CREATE
@@ -157,22 +173,23 @@ class ChatChannel(Channel):
elif context.type == ContextType.VOICE:
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
context["desire_rtype"] = ReplyType.VOICE
return context
def _handle(self, context: Context):
if context is None or not context.content:
return
logger.debug("[WX] ready to handle context: {}".format(context))
logger.debug("[chat_channel] handling context: {}".format(context))
# reply的构建步骤
reply = self._generate_reply(context)
logger.debug("[WX] ready to decorate reply: {}".format(reply))
# reply的包装步骤
reply = self._decorate_reply(context, reply)
logger.debug("[chat_channel] decorating reply: {}".format(reply))
# reply的发送步骤
self._send_reply(context, reply)
# reply的包装步骤
if reply and reply.content:
reply = self._decorate_reply(context, reply)
# reply的发送步骤
self._send_reply(context, reply)
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
e_context = PluginManager().emit_event(
@@ -183,9 +200,7 @@ class ChatChannel(Channel):
)
reply = e_context["reply"]
if not e_context.is_pass():
logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
if e_context.is_break():
context["generate_breaked_by"] = e_context["breaked_by"]
logger.debug("[chat_channel] type={}, content={}".format(context.type, context.content))
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
context["channel"] = e_context["channel"]
reply = super().build_reply_content(context.content, context)
@@ -197,7 +212,7 @@ class ChatChannel(Channel):
try:
any_to_wav(file_path, wav_path)
except Exception as e: # 转换失败直接使用mp3对于某些apimp3也可以识别
logger.warning("[WX]any to wav error, use raw path. " + str(e))
logger.warning("[chat_channel]any to wav error, use raw path. " + str(e))
wav_path = file_path
# 语音识别
reply = super().build_voice_to_text(wav_path)
@@ -208,7 +223,7 @@ class ChatChannel(Channel):
os.remove(wav_path)
except Exception as e:
pass
# logger.warning("[WX]delete temp file error: " + str(e))
# logger.warning("[chat_channel]delete temp file error: " + str(e))
if reply.type == ReplyType.TEXT:
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
@@ -226,7 +241,7 @@ class ChatChannel(Channel):
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
pass
else:
logger.warning("[WX] unknown context type: {}".format(context.type))
logger.warning("[chat_channel] unknown context type: {}".format(context.type))
return
return reply
@@ -242,7 +257,7 @@ class ChatChannel(Channel):
desire_rtype = context.get("desire_rtype")
if not e_context.is_pass() and reply and reply.type:
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
logger.error("[WX]reply type not support: " + str(reply.type))
logger.error("[chat_channel]reply type not support: " + str(reply.type))
reply.type = ReplyType.ERROR
reply.content = "不支持发送的消息类型: " + str(reply.type)
@@ -263,10 +278,10 @@ class ChatChannel(Channel):
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE or reply.type == ReplyType.FILE or reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL:
pass
else:
logger.error("[WX] unknown reply type: {}".format(reply.type))
logger.error("[chat_channel] unknown reply type: {}".format(reply.type))
return
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
logger.warning("[chat_channel] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
return reply
def _send_reply(self, context: Context, reply: Reply):
@@ -279,14 +294,107 @@ class ChatChannel(Channel):
)
reply = e_context["reply"]
if not e_context.is_pass() and reply and reply.type:
logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
logger.debug("[chat_channel] sending reply: {}, context: {}".format(reply, context))
# 如果是文本回复,尝试提取并发送图片
if reply.type == ReplyType.TEXT:
self._extract_and_send_images(reply, context)
# 如果是图片回复但带有文本内容,先发文本再发图片
elif reply.type == ReplyType.IMAGE_URL and hasattr(reply, 'text_content') and reply.text_content:
# 先发送文本
text_reply = Reply(ReplyType.TEXT, reply.text_content)
self._send(text_reply, context)
# 短暂延迟后发送图片
time.sleep(0.3)
self._send(reply, context)
else:
self._send(reply, context)
def _extract_and_send_images(self, reply: Reply, context: Context):
"""
从文本回复中提取图片/视频URL并单独发送
支持格式:[图片: /path/to/image.png], [视频: /path/to/video.mp4], ![](url), <img src="url">
最多发送5个媒体文件
"""
content = reply.content
media_items = [] # [(url, type), ...]
# 正则提取各种格式的媒体URL
patterns = [
(r'\[图片:\s*([^\]]+)\]', 'image'), # [图片: /path/to/image.png]
(r'\[视频:\s*([^\]]+)\]', 'video'), # [视频: /path/to/video.mp4]
(r'!\[.*?\]\(([^\)]+)\)', 'image'), # ![alt](url) - 默认图片
(r'<img[^>]+src=["\']([^"\']+)["\']', 'image'), # <img src="url">
(r'<video[^>]+src=["\']([^"\']+)["\']', 'video'), # <video src="url">
(r'https?://[^\s]+\.(?:jpg|jpeg|png|gif|webp)', 'image'), # 直接的图片URL
(r'https?://[^\s]+\.(?:mp4|avi|mov|wmv|flv)', 'video'), # 直接的视频URL
]
for pattern, media_type in patterns:
matches = re.findall(pattern, content, re.IGNORECASE)
for match in matches:
media_items.append((match, media_type))
# 去重保持顺序并限制最多5个
seen = set()
unique_items = []
for url, mtype in media_items:
if url not in seen:
seen.add(url)
unique_items.append((url, mtype))
media_items = unique_items[:5]
if media_items:
logger.info(f"[chat_channel] Extracted {len(media_items)} media item(s) from reply")
# 先发送文本(保持原文本不变)
logger.info(f"[chat_channel] Sending text content before media: {reply.content[:100]}...")
self._send(reply, context)
logger.info(f"[chat_channel] Text sent, now sending {len(media_items)} media item(s)")
# 然后逐个发送媒体文件
for i, (url, media_type) in enumerate(media_items):
try:
# 判断是本地文件还是URL
if url.startswith(('http://', 'https://')):
# 网络资源
if media_type == 'video':
# 视频使用 FILE 类型发送
media_reply = Reply(ReplyType.FILE, url)
media_reply.file_name = os.path.basename(url)
else:
# 图片使用 IMAGE_URL 类型
media_reply = Reply(ReplyType.IMAGE_URL, url)
elif os.path.exists(url):
# 本地文件
if media_type == 'video':
# 视频使用 FILE 类型,转换为 file:// URL
media_reply = Reply(ReplyType.FILE, f"file://{url}")
media_reply.file_name = os.path.basename(url)
else:
# 图片使用 IMAGE_URL 类型,转换为 file:// URL
media_reply = Reply(ReplyType.IMAGE_URL, f"file://{url}")
else:
logger.warning(f"[chat_channel] Media file not found or invalid URL: {url}")
continue
# 发送媒体文件(添加小延迟避免频率限制)
if i > 0:
time.sleep(0.5)
self._send(media_reply, context)
logger.info(f"[chat_channel] Sent {media_type} {i+1}/{len(media_items)}: {url[:50]}...")
except Exception as e:
logger.error(f"[chat_channel] Failed to send {media_type} {url}: {e}")
else:
# 没有媒体文件,正常发送文本
self._send(reply, context)
def _send(self, reply: Reply, context: Context, retry_cnt=0):
try:
self.send(reply, context)
except Exception as e:
logger.error("[WX] sendMsg error: {}".format(str(e)))
logger.error("[chat_channel] sendMsg error: {}".format(str(e)))
if isinstance(e, NotImplementedError):
return
logger.exception(e)
@@ -323,7 +431,7 @@ class ChatChannel(Channel):
if session_id not in self.sessions:
self.sessions[session_id] = [
Dequeue(),
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
threading.BoundedSemaphore(conf().get("concurrency_in_session", 1)),
]
if context.type == ContextType.TEXT and context.content.startswith("#"):
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
@@ -335,24 +443,27 @@ class ChatChannel(Channel):
while True:
with self.lock:
session_ids = list(self.sessions.keys())
for session_id in session_ids:
for session_id in session_ids:
with self.lock:
context_queue, semaphore = self.sessions[session_id]
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
if not context_queue.empty():
context = context_queue.get()
logger.debug("[WX] consume context: {}".format(context))
future: Future = self.handler_pool.submit(self._handle, context)
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
if not context_queue.empty():
context = context_queue.get()
logger.debug("[chat_channel] consume context: {}".format(context))
future: Future = handler_pool.submit(self._handle, context)
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
with self.lock:
if session_id not in self.futures:
self.futures[session_id] = []
self.futures[session_id].append(future)
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
with self.lock:
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
assert len(self.futures[session_id]) == 0, "thread pool error"
del self.sessions[session_id]
else:
semaphore.release()
time.sleep(0.1)
else:
semaphore.release()
time.sleep(0.2)
# 取消session_id对应的所有任务只能取消排队的消息和已提交线程池但未执行的任务
def cancel_session(self, session_id):

View File

@@ -1,5 +1,5 @@
"""
本类表示聊天消息用于对itchat和wechaty的消息进行统一的封装。
Unified chat message class for different channel implementations.
填好必填项(群聊6个非群聊8个)即可接入ChatChannel并支持插件参考TerminalChannel

View File

@@ -0,0 +1,970 @@
"""
钉钉通道接入
@author huiwen
@Date 2023/11/28
"""
import copy
import json
# -*- coding=utf-8 -*-
import logging
import os
import time
import requests
import dingtalk_stream
from dingtalk_stream import AckMessage
from dingtalk_stream.card_replier import AICardReplier
from dingtalk_stream.card_replier import AICardStatus
from dingtalk_stream.card_replier import CardReplier
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel
from common.utils import expand_path
from channel.dingtalk.dingtalk_message import DingTalkMessage
from common.expired_dict import ExpiredDict
from common.log import logger
from common.singleton import singleton
from common.time_check import time_checker
from config import conf
class CustomAICardReplier(CardReplier):
def __init__(self, dingtalk_client, incoming_message):
super(AICardReplier, self).__init__(dingtalk_client, incoming_message)
def start(
self,
card_template_id: str,
card_data: dict,
recipients: list = None,
support_forward: bool = True,
) -> str:
"""
AI卡片的创建接口
:param support_forward:
:param recipients:
:param card_template_id:
:param card_data:
:return:
"""
card_data_with_status = copy.deepcopy(card_data)
card_data_with_status["flowStatus"] = AICardStatus.PROCESSING
return self.create_and_send_card(
card_template_id,
card_data_with_status,
at_sender=True,
at_all=False,
recipients=recipients,
support_forward=support_forward,
)
# 对 AICardReplier 进行猴子补丁
AICardReplier.start = CustomAICardReplier.start
def _check(func):
def wrapper(self, cmsg: DingTalkMessage):
msgId = cmsg.msg_id
if msgId in self.receivedMsgs:
logger.info("DingTalk message {} already received, ignore".format(msgId))
return
self.receivedMsgs[msgId] = True
create_time = cmsg.create_time # 消息时间戳
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
logger.debug("[DingTalk] History message {} skipped".format(msgId))
return
if cmsg.my_msg and not cmsg.is_group:
logger.debug("[DingTalk] My message {} skipped".format(msgId))
return
return func(self, cmsg)
return wrapper
@singleton
class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
dingtalk_client_id = conf().get('dingtalk_client_id')
dingtalk_client_secret = conf().get('dingtalk_client_secret')
def setup_logger(self):
# Suppress verbose logs from dingtalk_stream SDK
logging.getLogger("dingtalk_stream").setLevel(logging.WARNING)
return logging.getLogger("DingTalk")
def __init__(self):
super().__init__()
super(dingtalk_stream.ChatbotHandler, self).__init__()
self.logger = self.setup_logger()
# 历史消息id暂存用于幂等控制
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
self._stream_client = None
self._running = False
self._event_loop = None
logger.debug("[DingTalk] client_id={}, client_secret={} ".format(
self.dingtalk_client_id, self.dingtalk_client_secret))
# 无需群校验和前缀
conf()["group_name_white_list"] = ["ALL_GROUP"]
# 单聊无需前缀
conf()["single_chat_prefix"] = [""]
# Access token cache
self._access_token = None
self._access_token_expires_at = 0
# Robot code cache (extracted from incoming messages)
self._robot_code = None
def _open_connection(self, client):
"""
Open a DingTalk stream connection directly, bypassing SDK's internal error-swallowing.
Returns (connection_dict, error_str). On success error_str is empty; on failure
connection_dict is None and error_str contains a human-readable message.
"""
try:
resp = requests.post(
"https://api.dingtalk.com/v1.0/gateway/connections/open",
headers={"Content-Type": "application/json", "Accept": "application/json"},
json={
"clientId": client.credential.client_id,
"clientSecret": client.credential.client_secret,
"subscriptions": [{"type": "CALLBACK",
"topic": dingtalk_stream.chatbot.ChatbotMessage.TOPIC}],
"ua": "dingtalk-sdk-python/cow",
"localIp": "",
},
timeout=10,
)
body = resp.json()
if not resp.ok:
code = body.get("code", resp.status_code)
message = body.get("message", resp.reason)
return None, f"open connection failed: [{code}] {message}"
return body, ""
except Exception as e:
return None, f"open connection failed: {e}"
def startup(self):
import asyncio
self.dingtalk_client_id = conf().get('dingtalk_client_id')
self.dingtalk_client_secret = conf().get('dingtalk_client_secret')
self._running = True
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
client = dingtalk_stream.DingTalkStreamClient(credential)
self._stream_client = client
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
logger.info("[DingTalk] ✅ Stream client initialized, ready to receive messages")
# Run the connection loop ourselves instead of delegating to client.start(),
# so we can get detailed error messages and respond to stop() quickly.
import urllib.parse as _urlparse
import websockets as _ws
import json as _json
client.pre_start()
_first_connect = True
while self._running:
# Open connection using our own request so we get detailed error info.
connection, err_msg = self._open_connection(client)
if connection is None:
if _first_connect:
logger.warning(f"[DingTalk] {err_msg}")
self.report_startup_error(err_msg)
_first_connect = False
else:
logger.warning(f"[DingTalk] {err_msg}, retrying in 10s...")
# Interruptible sleep: checks _running every 100ms.
for _ in range(100):
if not self._running:
break
time.sleep(0.1)
continue
if _first_connect:
logger.info("[DingTalk] ✅ Connected to DingTalk stream")
self.report_startup_success()
_first_connect = False
else:
logger.info("[DingTalk] Reconnected to DingTalk stream")
# Run the WebSocket session in an asyncio loop.
uri = '%s?ticket=%s' % (
connection['endpoint'],
_urlparse.quote_plus(connection['ticket'])
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self._event_loop = loop
try:
async def _session():
async with _ws.connect(uri) as websocket:
client.websocket = websocket
async for raw_message in websocket:
json_message = _json.loads(raw_message)
result = await client.route_message(json_message)
if result == dingtalk_stream.DingTalkStreamClient.TAG_DISCONNECT:
break
loop.run_until_complete(_session())
except (KeyboardInterrupt, SystemExit):
logger.info("[DingTalk] Session loop received stop signal, exiting")
break
except Exception as e:
if not self._running:
break
logger.warning(f"[DingTalk] Stream session error: {e}, reconnecting in 3s...")
for _ in range(30):
if not self._running:
break
time.sleep(0.1)
finally:
self._event_loop = None
try:
loop.close()
except Exception:
pass
logger.info("[DingTalk] Startup loop exited")
def stop(self):
logger.info("[DingTalk] stop() called, setting _running=False")
self._running = False
loop = self._event_loop
if loop and not loop.is_closed():
try:
loop.call_soon_threadsafe(loop.stop)
logger.info("[DingTalk] Sent stop signal to event loop")
except Exception as e:
logger.warning(f"[DingTalk] Error stopping event loop: {e}")
self._stream_client = None
logger.info("[DingTalk] stop() completed")
def get_access_token(self):
"""
获取企业内部应用的 access_token
文档: https://open.dingtalk.com/document/orgapp/obtain-orgapp-token
"""
current_time = time.time()
# 如果 token 还没过期,直接返回缓存的 token
if self._access_token and current_time < self._access_token_expires_at:
return self._access_token
# 获取新的 access_token
url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
headers = {"Content-Type": "application/json"}
data = {
"appKey": self.dingtalk_client_id,
"appSecret": self.dingtalk_client_secret
}
try:
response = requests.post(url, headers=headers, json=data, timeout=10)
result = response.json()
if response.status_code == 200 and "accessToken" in result:
self._access_token = result["accessToken"]
# Token 有效期为 2 小时,提前 5 分钟刷新
self._access_token_expires_at = current_time + result.get("expireIn", 7200) - 300
logger.info("[DingTalk] Access token refreshed successfully")
return self._access_token
else:
logger.error(f"[DingTalk] Failed to get access token: {result}")
return None
except Exception as e:
logger.error(f"[DingTalk] Error getting access token: {e}")
return None
def send_single_message(self, user_id: str, content: str, robot_code: str) -> bool:
"""
Send message to single user (private chat)
API: https://open.dingtalk.com/document/orgapp/chatbots-send-one-on-one-chat-messages-in-batches
"""
access_token = self.get_access_token()
if not access_token:
logger.error("[DingTalk] Failed to send single message: Access token not available.")
return False
if not robot_code:
logger.error("[DingTalk] Cannot send single message: robot_code is required")
return False
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
headers = {
"x-acs-dingtalk-access-token": access_token,
"Content-Type": "application/json"
}
data = {
"msgParam": json.dumps({"content": content}),
"msgKey": "sampleText",
"userIds": [user_id],
"robotCode": robot_code
}
logger.info(f"[DingTalk] Sending single message to user {user_id} with robot_code {robot_code}")
try:
response = requests.post(url, headers=headers, json=data, timeout=10)
result = response.json()
if response.status_code == 200 and result.get("processQueryKey"):
logger.info(f"[DingTalk] Single message sent successfully to {user_id}")
return True
else:
logger.error(f"[DingTalk] Failed to send single message: {result}")
return False
except Exception as e:
logger.error(f"[DingTalk] Error sending single message: {e}")
return False
def send_group_message(self, conversation_id: str, content: str, robot_code: str = None):
"""
主动发送群消息
文档: https://open.dingtalk.com/document/orgapp/the-robot-sends-a-group-message
Args:
conversation_id: 会话ID (openConversationId)
content: 消息内容
robot_code: 机器人编码,默认使用 dingtalk_client_id
"""
access_token = self.get_access_token()
if not access_token:
logger.error("[DingTalk] Cannot send group message: no access token")
return False
# Validate robot_code
if not robot_code:
logger.error("[DingTalk] Cannot send group message: robot_code is required")
return False
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
headers = {
"x-acs-dingtalk-access-token": access_token,
"Content-Type": "application/json"
}
data = {
"msgParam": json.dumps({"content": content}),
"msgKey": "sampleText",
"openConversationId": conversation_id,
"robotCode": robot_code
}
try:
response = requests.post(url, headers=headers, json=data, timeout=10)
result = response.json()
if response.status_code == 200:
logger.info(f"[DingTalk] Group message sent successfully to {conversation_id}")
return True
else:
logger.error(f"[DingTalk] Failed to send group message: {result}")
return False
except Exception as e:
logger.error(f"[DingTalk] Error sending group message: {e}")
return False
def upload_media(self, file_path: str, media_type: str = "image") -> str:
"""
上传媒体文件到钉钉
Args:
file_path: 本地文件路径或URL
media_type: 媒体类型 (image, video, voice, file)
Returns:
media_id如果上传失败返回 None
"""
access_token = self.get_access_token()
if not access_token:
logger.error("[DingTalk] Cannot upload media: no access token")
return None
# 处理 file:// URL
if file_path.startswith("file://"):
file_path = file_path[7:]
# 如果是 HTTP URL先下载
if file_path.startswith("http://") or file_path.startswith("https://"):
try:
import uuid
response = requests.get(file_path, timeout=(5, 60))
if response.status_code != 200:
logger.error(f"[DingTalk] Failed to download file from URL: {file_path}")
return None
# 保存到临时文件
file_name = os.path.basename(file_path) or f"media_{uuid.uuid4()}"
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(workspace_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
temp_file = os.path.join(tmp_dir, file_name)
with open(temp_file, "wb") as f:
f.write(response.content)
file_path = temp_file
logger.info(f"[DingTalk] Downloaded file to {file_path}")
except Exception as e:
logger.error(f"[DingTalk] Error downloading file: {e}")
return None
if not os.path.exists(file_path):
logger.error(f"[DingTalk] File not found: {file_path}")
return None
# 上传到钉钉
# 钉钉上传媒体文件 API: https://open.dingtalk.com/document/orgapp/upload-media-files
url = "https://oapi.dingtalk.com/media/upload"
params = {
"access_token": access_token,
"type": media_type
}
try:
with open(file_path, "rb") as f:
files = {"media": (os.path.basename(file_path), f)}
response = requests.post(url, params=params, files=files, timeout=(5, 60))
result = response.json()
if result.get("errcode") == 0:
media_id = result.get("media_id")
logger.info(f"[DingTalk] Media uploaded successfully, media_id={media_id}")
return media_id
else:
logger.error(f"[DingTalk] Failed to upload media: {result}")
return None
except Exception as e:
logger.error(f"[DingTalk] Error uploading media: {e}")
return None
def send_image_with_media_id(self, access_token: str, media_id: str, incoming_message, is_group: bool) -> bool:
"""
发送图片消息(使用 media_id
Args:
access_token: 访问令牌
media_id: 媒体ID
incoming_message: 钉钉消息对象
is_group: 是否为群聊
Returns:
是否发送成功
"""
headers = {
"x-acs-dingtalk-access-token": access_token,
'Content-Type': 'application/json'
}
msg_param = {
"photoURL": media_id # 钉钉图片消息使用 photoURL 字段
}
body = {
"robotCode": incoming_message.robot_code,
"msgKey": "sampleImageMsg",
"msgParam": json.dumps(msg_param),
}
if is_group:
# 群聊
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
body["openConversationId"] = incoming_message.conversation_id
else:
# 单聊
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
body["userIds"] = [incoming_message.sender_staff_id]
try:
response = requests.post(url=url, headers=headers, json=body, timeout=10)
result = response.json()
logger.info(f"[DingTalk] Image send result: {response.text}")
if response.status_code == 200:
return True
else:
logger.error(f"[DingTalk] Send image error: {response.text}")
return False
except Exception as e:
logger.error(f"[DingTalk] Send image exception: {e}")
return False
def send_image_message(self, receiver: str, media_id: str, is_group: bool, robot_code: str) -> bool:
"""
发送图片消息
Args:
receiver: 接收者ID (user_id 或 conversation_id)
media_id: 媒体ID
is_group: 是否为群聊
robot_code: 机器人编码
Returns:
是否发送成功
"""
access_token = self.get_access_token()
if not access_token:
logger.error("[DingTalk] Cannot send image: no access token")
return False
if not robot_code:
logger.error("[DingTalk] Cannot send image: robot_code is required")
return False
if is_group:
# 发送群聊图片
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
headers = {
"x-acs-dingtalk-access-token": access_token,
"Content-Type": "application/json"
}
data = {
"msgParam": json.dumps({"mediaId": media_id}),
"msgKey": "sampleImageMsg",
"openConversationId": receiver,
"robotCode": robot_code
}
else:
# 发送单聊图片
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
headers = {
"x-acs-dingtalk-access-token": access_token,
"Content-Type": "application/json"
}
data = {
"msgParam": json.dumps({"mediaId": media_id}),
"msgKey": "sampleImageMsg",
"userIds": [receiver],
"robotCode": robot_code
}
try:
response = requests.post(url, headers=headers, json=data, timeout=10)
result = response.json()
if response.status_code == 200:
logger.info(f"[DingTalk] Image message sent successfully")
return True
else:
logger.error(f"[DingTalk] Failed to send image message: {result}")
return False
except Exception as e:
logger.error(f"[DingTalk] Error sending image message: {e}")
return False
def get_image_download_url(self, download_code: str) -> str:
"""
获取图片下载地址
返回一个特殊的 URL 格式dingtalk://download/{robot_code}:{download_code}
后续会在 download_image_file 中使用新版 API 下载
"""
# 获取 robot_code
if not hasattr(self, '_robot_code_cache'):
self._robot_code_cache = None
robot_code = self._robot_code_cache
if not robot_code:
logger.error("[DingTalk] robot_code not available for image download")
return None
# 返回一个特殊的 URL包含 robot_code 和 download_code
logger.info(f"[DingTalk] Successfully got image download URL for code: {download_code}")
return f"dingtalk://download/{robot_code}:{download_code}"
async def process(self, callback: dingtalk_stream.CallbackMessage):
try:
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
# 缓存 robot_code用于后续图片下载
if hasattr(incoming_message, 'robot_code'):
self._robot_code_cache = incoming_message.robot_code
# Filter out stale messages from before channel startup (offline backlog)
create_at = getattr(incoming_message, 'create_at', None)
if create_at:
msg_age_s = time.time() - int(create_at) / 1000
if msg_age_s > 60:
logger.warning(f"[DingTalk] stale msg filtered (age={msg_age_s:.0f}s), "
f"msg_id={getattr(incoming_message, 'message_id', 'N/A')}")
return AckMessage.STATUS_OK, 'OK'
image_download_handler = self
dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler)
if dingtalk_msg.is_group:
self.handle_group(dingtalk_msg)
else:
self.handle_single(dingtalk_msg)
return AckMessage.STATUS_OK, 'OK'
except Exception as e:
logger.error(f"[DingTalk] process error: {e}", exc_info=True)
return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR'
@time_checker
@_check
def handle_single(self, cmsg: DingTalkMessage):
# 处理单聊消息
if cmsg.ctype == ContextType.VOICE:
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.IMAGE:
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.IMAGE_CREATE:
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.PATPAT:
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.TEXT:
logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
else:
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
# 处理文件缓存逻辑
from channel.file_cache import get_file_cache
file_cache = get_file_cache()
# 单聊的 session_id 就是 sender_id
session_id = cmsg.from_user_id
# 如果是单张图片消息,缓存起来
if cmsg.ctype == ContextType.IMAGE:
if hasattr(cmsg, 'image_path') and cmsg.image_path:
file_cache.add(session_id, cmsg.image_path, file_type='image')
logger.info(f"[DingTalk] Image cached for session {session_id}, waiting for user query...")
# 单张图片不直接处理,等待用户提问
return
# 如果是文本消息,检查是否有缓存的文件
if cmsg.ctype == ContextType.TEXT:
cached_files = file_cache.get(session_id)
if cached_files:
# 将缓存的文件附加到文本消息中
file_refs = []
for file_info in cached_files:
file_path = file_info['path']
file_type = file_info['type']
if file_type == 'image':
file_refs.append(f"[图片: {file_path}]")
elif file_type == 'video':
file_refs.append(f"[视频: {file_path}]")
else:
file_refs.append(f"[文件: {file_path}]")
cmsg.content = cmsg.content + "\n" + "\n".join(file_refs)
logger.info(f"[DingTalk] Attached {len(cached_files)} cached file(s) to user query")
# 清除缓存
file_cache.clear(session_id)
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
if context:
self.produce(context)
@time_checker
@_check
def handle_group(self, cmsg: DingTalkMessage):
# 处理群聊消息
if cmsg.ctype == ContextType.VOICE:
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.IMAGE:
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.IMAGE_CREATE:
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.PATPAT:
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.TEXT:
logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
else:
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
# 处理文件缓存逻辑
from channel.file_cache import get_file_cache
file_cache = get_file_cache()
# 群聊的 session_id
if conf().get("group_shared_session", True):
session_id = cmsg.other_user_id # conversation_id
else:
session_id = cmsg.from_user_id + "_" + cmsg.other_user_id
# 如果是单张图片消息,缓存起来
if cmsg.ctype == ContextType.IMAGE:
if hasattr(cmsg, 'image_path') and cmsg.image_path:
file_cache.add(session_id, cmsg.image_path, file_type='image')
logger.info(f"[DingTalk] Image cached for session {session_id}, waiting for user query...")
# 单张图片不直接处理,等待用户提问
return
# 如果是文本消息,检查是否有缓存的文件
if cmsg.ctype == ContextType.TEXT:
cached_files = file_cache.get(session_id)
if cached_files:
# 将缓存的文件附加到文本消息中
file_refs = []
for file_info in cached_files:
file_path = file_info['path']
file_type = file_info['type']
if file_type == 'image':
file_refs.append(f"[图片: {file_path}]")
elif file_type == 'video':
file_refs.append(f"[视频: {file_path}]")
else:
file_refs.append(f"[文件: {file_path}]")
cmsg.content = cmsg.content + "\n" + "\n".join(file_refs)
logger.info(f"[DingTalk] Attached {len(cached_files)} cached file(s) to user query")
# 清除缓存
file_cache.clear(session_id)
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
context['no_need_at'] = True
if context:
self.produce(context)
def send(self, reply: Reply, context: Context):
logger.debug(f"[DingTalk] send() called with reply.type={reply.type}, content_length={len(str(reply.content))}")
receiver = context["receiver"]
# Check if msg exists (for scheduled tasks, msg might be None)
msg = context.kwargs.get('msg')
if msg is None:
# 定时任务场景:使用主动发送 API
is_group = context.get("isgroup", False)
logger.info(f"[DingTalk] Sending scheduled task message to {receiver} (is_group={is_group})")
# 使用缓存的 robot_code 或配置的值
robot_code = self._robot_code or conf().get("dingtalk_robot_code")
logger.info(f"[DingTalk] Using robot_code: {robot_code}, cached: {self._robot_code}, config: {conf().get('dingtalk_robot_code')}")
if not robot_code:
logger.error(f"[DingTalk] Cannot send scheduled task: robot_code not available. Please send at least one message to the bot first, or configure dingtalk_robot_code in config.json")
return
# 根据是否群聊选择不同的 API
if is_group:
success = self.send_group_message(receiver, reply.content, robot_code)
else:
# 单聊场景:尝试从 context 中获取 dingtalk_sender_staff_id
sender_staff_id = context.get("dingtalk_sender_staff_id")
if not sender_staff_id:
logger.error(f"[DingTalk] Cannot send single chat scheduled message: sender_staff_id not available in context")
return
logger.info(f"[DingTalk] Sending single message to staff_id: {sender_staff_id}")
success = self.send_single_message(sender_staff_id, reply.content, robot_code)
if not success:
logger.error(f"[DingTalk] Failed to send scheduled task message")
return
# 从正常消息中提取并缓存 robot_code
if hasattr(msg, 'robot_code'):
robot_code = msg.robot_code
if robot_code and robot_code != self._robot_code:
self._robot_code = robot_code
logger.debug(f"[DingTalk] Cached robot_code: {robot_code}")
isgroup = msg.is_group
incoming_message = msg.incoming_message
robot_code = self._robot_code or conf().get("dingtalk_robot_code")
# 处理图片和视频发送
if reply.type == ReplyType.IMAGE_URL:
logger.info(f"[DingTalk] Sending image: {reply.content}")
# 如果有附加的文本内容,先发送文本
if hasattr(reply, 'text_content') and reply.text_content:
self.reply_text(reply.text_content, incoming_message)
import time
time.sleep(0.3) # 短暂延迟,确保文本先到达
media_id = self.upload_media(reply.content, media_type="image")
if media_id:
# 使用主动发送 API 发送图片
access_token = self.get_access_token()
if access_token:
success = self.send_image_with_media_id(
access_token,
media_id,
incoming_message,
isgroup
)
if not success:
logger.error("[DingTalk] Failed to send image message")
self.reply_text("抱歉,图片发送失败", incoming_message)
else:
logger.error("[DingTalk] Cannot get access token")
self.reply_text("抱歉图片发送失败无法获取token", incoming_message)
else:
logger.error("[DingTalk] Failed to upload image")
self.reply_text("抱歉,图片上传失败", incoming_message)
return
elif reply.type == ReplyType.FILE:
# 如果有附加的文本内容,先发送文本
if hasattr(reply, 'text_content') and reply.text_content:
self.reply_text(reply.text_content, incoming_message)
import time
time.sleep(0.3) # 短暂延迟,确保文本先到达
# 判断是否为视频文件
file_path = reply.content
if file_path.startswith("file://"):
file_path = file_path[7:]
is_video = file_path.lower().endswith(('.mp4', '.avi', '.mov', '.wmv', '.flv'))
access_token = self.get_access_token()
if not access_token:
logger.error("[DingTalk] Cannot get access token")
self.reply_text("抱歉文件发送失败无法获取token", incoming_message)
return
if is_video:
logger.info(f"[DingTalk] Sending video: {reply.content}")
media_id = self.upload_media(reply.content, media_type="video")
if media_id:
# 发送视频消息
msg_param = {
"duration": "30", # TODO: 获取实际视频时长
"videoMediaId": media_id,
"videoType": "mp4",
"height": "400",
"width": "600",
}
success = self._send_file_message(
access_token,
incoming_message,
"sampleVideo",
msg_param,
isgroup
)
if not success:
self.reply_text("抱歉,视频发送失败", incoming_message)
else:
logger.error("[DingTalk] Failed to upload video")
self.reply_text("抱歉,视频上传失败", incoming_message)
else:
# 其他文件类型
logger.info(f"[DingTalk] Sending file: {reply.content}")
media_id = self.upload_media(reply.content, media_type="file")
if media_id:
file_name = os.path.basename(file_path)
file_base, file_extension = os.path.splitext(file_name)
msg_param = {
"mediaId": media_id,
"fileName": file_name,
"fileType": file_extension[1:] if file_extension else "file"
}
success = self._send_file_message(
access_token,
incoming_message,
"sampleFile",
msg_param,
isgroup
)
if not success:
self.reply_text("抱歉,文件发送失败", incoming_message)
else:
logger.error("[DingTalk] Failed to upload file")
self.reply_text("抱歉,文件上传失败", incoming_message)
return
# 处理文本消息
elif reply.type == ReplyType.TEXT:
logger.info(f"[DingTalk] Sending text message, length={len(reply.content)}")
if conf().get("dingtalk_card_enabled"):
logger.info("[Dingtalk] sendMsg={}, receiver={}".format(reply, receiver))
def reply_with_text():
self.reply_text(reply.content, incoming_message)
def reply_with_at_text():
self.reply_text("📢 您有一条新的消息,请查看。", incoming_message)
def reply_with_ai_markdown():
button_list, markdown_content = self.generate_button_markdown_content(context, reply)
self.reply_ai_markdown_button(incoming_message, markdown_content, button_list, "", "📌 内容由AI生成", "",[incoming_message.sender_staff_id])
if reply.type in [ReplyType.IMAGE_URL, ReplyType.IMAGE, ReplyType.TEXT]:
if isgroup:
reply_with_ai_markdown()
reply_with_at_text()
else:
reply_with_ai_markdown()
else:
# 暂不支持其它类型消息回复
reply_with_text()
else:
self.reply_text(reply.content, incoming_message)
return
def _send_file_message(self, access_token: str, incoming_message, msg_key: str, msg_param: dict, is_group: bool) -> bool:
"""
发送文件/视频消息的通用方法
Args:
access_token: 访问令牌
incoming_message: 钉钉消息对象
msg_key: 消息类型 (sampleFile, sampleVideo, sampleAudio)
msg_param: 消息参数
is_group: 是否为群聊
Returns:
是否发送成功
"""
headers = {
"x-acs-dingtalk-access-token": access_token,
'Content-Type': 'application/json'
}
body = {
"robotCode": incoming_message.robot_code,
"msgKey": msg_key,
"msgParam": json.dumps(msg_param),
}
if is_group:
# 群聊
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
body["openConversationId"] = incoming_message.conversation_id
else:
# 单聊
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
body["userIds"] = [incoming_message.sender_staff_id]
try:
response = requests.post(url=url, headers=headers, json=body, timeout=10)
result = response.json()
logger.info(f"[DingTalk] File send result: {response.text}")
if response.status_code == 200:
return True
else:
logger.error(f"[DingTalk] Send file error: {response.text}")
return False
except Exception as e:
logger.error(f"[DingTalk] Send file exception: {e}")
return False
def generate_button_markdown_content(self, context, reply):
image_url = context.kwargs.get("image_url")
promptEn = context.kwargs.get("promptEn")
reply_text = reply.content
button_list = []
markdown_content = f"""
{reply.content}
"""
if image_url is not None and promptEn is not None:
button_list = [
{"text": "查看原图", "url": image_url, "iosUrl": image_url, "color": "blue"}
]
markdown_content = f"""
{promptEn}
!["图片"]({image_url})
{reply_text}
"""
logger.debug(f"[Dingtalk] generate_button_markdown_content, button_list={button_list} , markdown_content={markdown_content}")
return button_list, markdown_content

View File

@@ -0,0 +1,244 @@
import os
import re
import requests
from dingtalk_stream import ChatbotMessage
from bridge.context import ContextType
from channel.chat_message import ChatMessage
# -*- coding=utf-8 -*-
from common.log import logger
from common.tmp_dir import TmpDir
from common.utils import expand_path
from config import conf
class DingTalkMessage(ChatMessage):
def __init__(self, event: ChatbotMessage, image_download_handler):
super().__init__(event)
self.image_download_handler = image_download_handler
self.msg_id = event.message_id
self.message_type = event.message_type
self.incoming_message = event
self.sender_staff_id = event.sender_staff_id
self.other_user_id = event.conversation_id
self.create_time = event.create_at
self.image_content = event.image_content
self.rich_text_content = event.rich_text_content
self.robot_code = event.robot_code # 机器人编码
if event.conversation_type == "1":
self.is_group = False
else:
self.is_group = True
if self.message_type == "text":
self.ctype = ContextType.TEXT
self.content = event.text.content.strip()
elif self.message_type == "audio":
# 钉钉支持直接识别语音,所以此处将直接提取文字,当文字处理
self.content = event.extensions['content']['recognition'].strip()
self.ctype = ContextType.TEXT
elif (self.message_type == 'picture') or (self.message_type == 'richText'):
# 钉钉图片类型或富文本类型消息处理
image_list = event.get_image_list()
if self.message_type == 'picture' and len(image_list) > 0:
# 单张图片消息:下载到工作空间,用于文件缓存
self.ctype = ContextType.IMAGE
download_code = image_list[0]
download_url = image_download_handler.get_image_download_url(download_code)
# 下载到工作空间 tmp 目录
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(workspace_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
image_path = download_image_file(download_url, tmp_dir)
if image_path:
self.content = image_path
self.image_path = image_path # 保存图片路径用于缓存
logger.info(f"[DingTalk] Downloaded single image to {image_path}")
else:
self.content = "[图片下载失败]"
self.image_path = None
elif self.message_type == 'richText' and len(image_list) > 0:
# 富文本消息:下载所有图片并附加到文本中
self.ctype = ContextType.TEXT
# 下载到工作空间 tmp 目录
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(workspace_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
# 提取富文本中的文本内容
text_content = ""
if self.rich_text_content:
# rich_text_content 是一个 RichTextContent 对象,需要从中提取文本
text_list = event.get_text_list()
if text_list:
text_content = "".join(text_list).strip()
# 下载所有图片
image_paths = []
for download_code in image_list:
download_url = image_download_handler.get_image_download_url(download_code)
image_path = download_image_file(download_url, tmp_dir)
if image_path:
image_paths.append(image_path)
# 构建消息内容:文本 + 图片路径
content_parts = []
if text_content:
content_parts.append(text_content)
for img_path in image_paths:
content_parts.append(f"[图片: {img_path}]")
self.content = "\n".join(content_parts) if content_parts else "[富文本消息]"
logger.info(f"[DingTalk] Received richText with {len(image_paths)} image(s): {self.content}")
else:
self.ctype = ContextType.IMAGE
self.content = "[未找到图片]"
logger.debug(f"[DingTalk] messageType: {self.message_type}, imageList isEmpty")
if self.is_group:
self.from_user_id = event.conversation_id
self.actual_user_id = event.sender_id
self.is_at = True
else:
self.from_user_id = event.sender_id
self.actual_user_id = event.sender_id
self.to_user_id = event.chatbot_user_id
self.other_user_nickname = event.conversation_title
def download_image_file(image_url, temp_dir):
"""
下载图片文件
支持两种方式:
1. 普通 HTTP(S) URL
2. 钉钉 downloadCode: dingtalk://download/{download_code}
"""
# 检查临时目录是否存在,如果不存在则创建
if not os.path.exists(temp_dir):
os.makedirs(temp_dir)
# 处理钉钉 downloadCode
if image_url.startswith("dingtalk://download/"):
download_code = image_url.replace("dingtalk://download/", "")
logger.info(f"[DingTalk] Downloading image with downloadCode: {download_code[:20]}...")
# 需要从外部传入 access_token这里先用一个临时方案
# 从 config 获取 dingtalk_client_id 和 dingtalk_client_secret
from config import conf
client_id = conf().get("dingtalk_client_id")
client_secret = conf().get("dingtalk_client_secret")
if not client_id or not client_secret:
logger.error("[DingTalk] Missing dingtalk_client_id or dingtalk_client_secret")
return None
# 解析 robot_code 和 download_code
parts = download_code.split(":", 1)
if len(parts) != 2:
logger.error(f"[DingTalk] Invalid download_code format (expected robot_code:download_code): {download_code[:50]}")
return None
robot_code, actual_download_code = parts
# 获取 access_token使用新版 API
token_url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
token_headers = {
"Content-Type": "application/json"
}
token_body = {
"appKey": client_id,
"appSecret": client_secret
}
try:
token_response = requests.post(token_url, json=token_body, headers=token_headers, timeout=10)
if token_response.status_code == 200:
token_data = token_response.json()
access_token = token_data.get("accessToken")
if not access_token:
logger.error(f"[DingTalk] Failed to get access token: {token_data}")
return None
# 获取下载 URL使用新版 API
download_api_url = "https://api.dingtalk.com/v1.0/robot/messageFiles/download"
download_headers = {
"x-acs-dingtalk-access-token": access_token,
"Content-Type": "application/json"
}
download_body = {
"downloadCode": actual_download_code,
"robotCode": robot_code
}
download_response = requests.post(download_api_url, json=download_body, headers=download_headers, timeout=10)
if download_response.status_code == 200:
download_data = download_response.json()
download_url = download_data.get("downloadUrl")
if not download_url:
logger.error(f"[DingTalk] No downloadUrl in response: {download_data}")
return None
# 从 downloadUrl 下载实际图片
image_response = requests.get(download_url, stream=True, timeout=60)
if image_response.status_code == 200:
# 生成文件名(使用 download_code 的 hash避免特殊字符
import hashlib
file_hash = hashlib.md5(actual_download_code.encode()).hexdigest()[:16]
file_name = f"{file_hash}.png"
file_path = os.path.join(temp_dir, file_name)
with open(file_path, 'wb') as file:
file.write(image_response.content)
logger.info(f"[DingTalk] Image downloaded successfully: {file_path}")
return file_path
else:
logger.error(f"[DingTalk] Failed to download image from URL: {image_response.status_code}")
return None
else:
logger.error(f"[DingTalk] Failed to get download URL: {download_response.status_code}, {download_response.text}")
return None
else:
logger.error(f"[DingTalk] Failed to get access token: {token_response.status_code}, {token_response.text}")
return None
except Exception as e:
logger.error(f"[DingTalk] Exception downloading image: {e}")
import traceback
logger.error(traceback.format_exc())
return None
# 普通 HTTP(S) URL
else:
headers = {
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36'
}
try:
response = requests.get(image_url, headers=headers, stream=True, timeout=60 * 5)
if response.status_code == 200:
# 生成文件名
file_name = image_url.split("/")[-1].split("?")[0]
# 将文件保存到临时目录
file_path = os.path.join(temp_dir, file_name)
with open(file_path, 'wb') as file:
file.write(response.content)
return file_path
else:
logger.info(f"[Dingtalk] Failed to download image file, {response.content}")
return None
except Exception as e:
logger.error(f"[Dingtalk] Exception downloading image: {e}")
return None

184
channel/feishu/README.md Normal file
View File

@@ -0,0 +1,184 @@
# 飞书Channel使用说明
飞书Channel支持两种事件接收模式可以根据部署环境灵活选择。
## 模式对比
| 模式 | 适用场景 | 优点 | 缺点 |
|------|---------|------|------|
| **webhook** | 生产环境 | 稳定可靠,官方推荐 | 需要公网IP或域名 |
| **websocket** | 本地开发 | 无需公网IP开发便捷 | 需要额外依赖 |
## 配置说明
### 基础配置
`config.json` 中添加以下配置:
```json
{
"channel_type": "feishu",
"feishu_app_id": "cli_xxxxx",
"feishu_app_secret": "your_app_secret",
"feishu_token": "your_verification_token",
"feishu_bot_name": "你的机器人名称",
"feishu_event_mode": "webhook",
"feishu_port": 9891
}
```
### 配置项说明
- `feishu_app_id`: 飞书应用的App ID
- `feishu_app_secret`: 飞书应用的App Secret
- `feishu_token`: 事件订阅的Verification Token
- `feishu_bot_name`: 机器人名称(用于群聊@判断)
- `feishu_event_mode`: 事件接收模式,可选值:
- `"websocket"`: 长连接模式(默认)
- `"webhook"`: HTTP服务器模式
- `feishu_port`: webhook模式下的HTTP服务端口(默认9891)
## 模式一: Webhook模式(推荐生产环境)
### 1. 配置
```json
{
"feishu_event_mode": "webhook",
"feishu_port": 9891
}
```
### 2. 启动服务
```bash
python3 app.py
```
服务将在 `http://0.0.0.0:9891` 启动。
### 3. 配置飞书应用
1. 登录[飞书开放平台](https://open.feishu.cn/)
2. 进入应用详情 -> 事件订阅
3. 选择 **将事件发送至开发者服务器**
4. 填写请求地址: `http://your-domain:9891/`
5. 添加事件: `im.message.receive_v1` (接收消息v2.0)
6. 保存配置
### 4. 注意事项
- 需要有公网IP或域名
- 确保防火墙开放对应端口
- 建议使用HTTPS(需要配置反向代理)
## 模式二: WebSocket模式(推荐本地开发)
### 1. 安装依赖
```bash
pip install lark-oapi
```
### 2. 配置
```json
{
"feishu_event_mode": "websocket"
}
```
### 3. 启动服务
```bash
python3 app.py
```
程序将自动建立与飞书开放平台的长连接。
### 4. 配置飞书应用
1. 登录[飞书开放平台](https://open.feishu.cn/)
2. 进入应用详情 -> 事件订阅
3. 选择 **使用长连接接收事件**
4. 添加事件: `im.message.receive_v1` (接收消息v2.0)
5. 保存配置
### 5. 注意事项
- 无需公网IP
- 需要能访问公网(建立WebSocket连接)
- 每个应用最多50个连接
- 集群模式下消息随机分发到一个客户端
## 平滑迁移
从webhook模式切换到websocket模式(或反向切换):
1. 修改 `config.json` 中的 `feishu_event_mode`
2. 如果切换到websocket模式安装 `lark-oapi` 依赖
3. 重启服务
4. 在飞书开放平台修改事件订阅方式
**重要**: 同一时间只能使用一种模式,否则会导致消息重复接收。
## 消息去重机制
两种模式都使用相同的消息去重机制:
- 使用 `ExpiredDict` 存储已处理的消息ID
- 过期时间: 7.1小时
- 确保消息不会重复处理
## 故障排查
### WebSocket模式连接失败
```
[FeiShu] lark_oapi not installed
```
**解决**: 安装依赖 `pip install lark-oapi`
### SSL证书验证失败
```
[Lark][ERROR] connect failed, err:[SSL:CERTIFICATE_VERIFY_FAILED] certificate verify failed: self signed certificate in certificate chain
```
**原因**: 网络环境中存在自签名证书或SSL中间人代理(如企业代理、VPN等)
**解决**: 程序会自动检测SSL证书验证失败并自动重试禁用证书验证的连接。无需手动配置。
当遇到证书错误时,日志会显示:
```
[FeiShu] SSL certificate verification disabled due to certificate error. This may happen when using corporate proxy or self-signed certificates.
```
这是正常现象,程序会自动处理并继续运行。
### Webhook模式端口被占用
```
Address already in use
```
**解决**: 修改 `feishu_port` 配置或关闭占用端口的进程
### 收不到消息
1. 检查飞书应用的事件订阅配置
2. 确认已添加 `im.message.receive_v1` 事件
3. 检查应用权限: 需要 `im:message` 权限
4. 查看日志中的错误信息
## 开发建议
- **本地开发**: 使用websocket模式快速迭代
- **测试环境**: 可以使用webhook模式 + 内网穿透工具(如ngrok)
- **生产环境**: 使用webhook模式配置正式域名和HTTPS
## 参考文档
- [飞书开放平台 - 事件订阅](https://open.feishu.cn/document/ukTMukTMukTM/uUTNz4SN1MjL1UzM)
- [飞书SDK - Python](https://github.com/larksuite/oapi-sdk-python)

View File

@@ -1,89 +1,478 @@
"""
飞书通道接入
支持两种事件接收模式:
1. webhook模式: 通过HTTP服务器接收事件(需要公网IP)
2. websocket模式: 通过长连接接收事件(本地开发友好)
通过配置项 feishu_event_mode 选择模式: "webhook""websocket"
@author Saboteur7
@Date 2023/11/19
"""
import importlib.util
import json
import logging
import os
import ssl
import threading
# -*- coding=utf-8 -*-
import uuid
import requests
import web
from channel.feishu.feishu_message import FeishuMessage
from bridge.context import Context
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel, check_prefix
from channel.feishu.feishu_message import FeishuMessage
from common import utils
from common.expired_dict import ExpiredDict
from common.log import logger
from common.singleton import singleton
from config import conf
from common.expired_dict import ExpiredDict
from bridge.context import ContextType
from channel.chat_channel import ChatChannel, check_prefix
from common import utils
import json
import os
# Suppress verbose logs from Lark SDK
logging.getLogger("Lark").setLevel(logging.WARNING)
URL_VERIFICATION = "url_verification"
# Lazy-check for lark_oapi SDK availability without importing it at module level.
# The full `import lark_oapi` pulls in 10k+ files and takes 4-10s, so we defer
# the actual import to _startup_websocket() where it is needed.
LARK_SDK_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
lark = None # will be populated on first use via _ensure_lark_imported()
def _ensure_lark_imported():
"""Import lark_oapi on first use (takes 4-10s due to 10k+ source files)."""
global lark
if lark is None:
import lark_oapi as _lark
lark = _lark
return lark
@singleton
class FeiShuChanel(ChatChannel):
feishu_app_id = conf().get('feishu_app_id')
feishu_app_secret = conf().get('feishu_app_secret')
feishu_token = conf().get('feishu_token')
feishu_event_mode = conf().get('feishu_event_mode', 'websocket') # webhook 或 websocket
def __init__(self):
super().__init__()
# 历史消息id暂存用于幂等控制
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
logger.info("[FeiShu] app_id={}, app_secret={} verification_token={}".format(
self.feishu_app_id, self.feishu_app_secret, self.feishu_token))
self._http_server = None
self._ws_client = None
self._ws_thread = None
self._bot_open_id = None # cached bot open_id for @-mention matching
logger.debug("[FeiShu] app_id={}, app_secret={}, verification_token={}, event_mode={}".format(
self.feishu_app_id, self.feishu_app_secret, self.feishu_token, self.feishu_event_mode))
# 无需群校验和前缀
conf()["group_name_white_list"] = ["ALL_GROUP"]
conf()["single_chat_prefix"] = []
conf()["single_chat_prefix"] = [""]
# 验证配置
if self.feishu_event_mode == 'websocket' and not LARK_SDK_AVAILABLE:
logger.error("[FeiShu] websocket mode requires lark_oapi. Please install: pip install lark-oapi")
raise Exception("lark_oapi not installed")
def startup(self):
self.feishu_app_id = conf().get('feishu_app_id')
self.feishu_app_secret = conf().get('feishu_app_secret')
self.feishu_token = conf().get('feishu_token')
self.feishu_event_mode = conf().get('feishu_event_mode', 'websocket')
self._fetch_bot_open_id()
if self.feishu_event_mode == 'websocket':
self._startup_websocket()
else:
self._startup_webhook()
def _fetch_bot_open_id(self):
"""Fetch the bot's own open_id via API so we can match @-mentions without feishu_bot_name."""
try:
access_token = self.fetch_access_token()
if not access_token:
logger.warning("[FeiShu] Cannot fetch bot info: no access_token")
return
headers = {"Authorization": "Bearer " + access_token}
resp = requests.get("https://open.feishu.cn/open-apis/bot/v3/info/", headers=headers, timeout=5)
if resp.status_code == 200:
data = resp.json()
if data.get("code") == 0:
self._bot_open_id = data.get("bot", {}).get("open_id")
logger.info(f"[FeiShu] Bot open_id fetched: {self._bot_open_id}")
else:
logger.warning(f"[FeiShu] Fetch bot info failed: code={data.get('code')}, msg={data.get('msg')}")
except Exception as e:
logger.warning(f"[FeiShu] Fetch bot open_id error: {e}")
def stop(self):
import ctypes
logger.info("[FeiShu] stop() called")
ws_client = self._ws_client
self._ws_client = None
ws_thread = self._ws_thread
self._ws_thread = None
# Interrupt the ws thread first so its blocking start() unblocks
if ws_thread and ws_thread.is_alive():
try:
tid = ws_thread.ident
if tid:
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
)
if res == 1:
logger.info("[FeiShu] Interrupted ws thread via ctypes")
elif res > 1:
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
except Exception as e:
logger.warning(f"[FeiShu] Error interrupting ws thread: {e}")
# lark.ws.Client has no stop() method; thread interruption above is sufficient
if self._http_server:
try:
self._http_server.stop()
logger.info("[FeiShu] HTTP server stopped")
except Exception as e:
logger.warning(f"[FeiShu] Error stopping HTTP server: {e}")
self._http_server = None
logger.info("[FeiShu] stop() completed")
def _startup_webhook(self):
"""启动HTTP服务器接收事件(webhook模式)"""
logger.debug("[FeiShu] Starting in webhook mode...")
urls = (
'/', 'channel.feishu.feishu_channel.FeishuController'
)
app = web.application(urls, globals(), autoreload=False)
port = conf().get("feishu_port", 9891)
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
func = web.httpserver.StaticMiddleware(app.wsgifunc())
func = web.httpserver.LogMiddleware(func)
server = web.httpserver.WSGIServer(("0.0.0.0", port), func)
self._http_server = server
try:
server.start()
except (KeyboardInterrupt, SystemExit):
server.stop()
def _startup_websocket(self):
"""启动长连接接收事件(websocket模式)"""
_ensure_lark_imported()
logger.debug("[FeiShu] Starting in websocket mode...")
# 创建事件处理器
def handle_message_event(data: lark.im.v1.P2ImMessageReceiveV1) -> None:
"""处理接收消息事件 v2.0"""
try:
event_dict = json.loads(lark.JSON.marshal(data))
event = event_dict.get("event", {})
msg = event.get("message", {})
# Skip group messages that don't @-mention the bot (reduce log noise)
if msg.get("chat_type") == "group" and not msg.get("mentions") and msg.get("message_type") == "text":
return
logger.debug(f"[FeiShu] websocket receive event: {lark.JSON.marshal(data, indent=2)}")
# 处理消息
self._handle_message_event(event)
except Exception as e:
logger.error(f"[FeiShu] websocket handle message error: {e}", exc_info=True)
# 构建事件分发器
event_handler = lark.EventDispatcherHandler.builder("", "") \
.register_p2_im_message_receive_v1(handle_message_event) \
.build()
def start_client_with_retry():
"""Run ws client in this thread with its own event loop to avoid conflicts."""
import asyncio
import ssl as ssl_module
original_create_default_context = ssl_module.create_default_context
def create_unverified_context(*args, **kwargs):
context = original_create_default_context(*args, **kwargs)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
return context
# lark_oapi.ws.client captures the event loop at module-import time as a module-
# level global variable. When a previous ws thread is force-killed via ctypes its
# loop may still be marked as "running", which causes the next ws_client.start()
# call (in this new thread) to raise "This event loop is already running".
# Fix: replace the module-level loop with a brand-new, idle loop before starting.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
import lark_oapi.ws.client as _lark_ws_client_mod
_lark_ws_client_mod.loop = loop
except Exception:
pass
startup_error = None
for attempt in range(2):
try:
if attempt == 1:
logger.warning("[FeiShu] Retrying with SSL verification disabled...")
ssl_module.create_default_context = create_unverified_context
ssl_module._create_unverified_context = create_unverified_context
ws_client = lark.ws.Client(
self.feishu_app_id,
self.feishu_app_secret,
event_handler=event_handler,
log_level=lark.LogLevel.WARNING
)
self._ws_client = ws_client
logger.debug("[FeiShu] Websocket client starting...")
ws_client.start()
break
except (SystemExit, KeyboardInterrupt):
logger.info("[FeiShu] Websocket thread received stop signal")
break
except Exception as e:
error_msg = str(e)
is_ssl_error = ("CERTIFICATE_VERIFY_FAILED" in error_msg
or "certificate verify failed" in error_msg.lower())
if is_ssl_error and attempt == 0:
logger.warning(f"[FeiShu] SSL error: {error_msg}, retrying...")
continue
logger.error(f"[FeiShu] Websocket client error: {e}", exc_info=True)
startup_error = error_msg
ssl_module.create_default_context = original_create_default_context
break
if startup_error:
self.report_startup_error(startup_error)
try:
loop.close()
except Exception:
pass
logger.info("[FeiShu] Websocket thread exited")
ws_thread = threading.Thread(target=start_client_with_retry, daemon=True)
self._ws_thread = ws_thread
ws_thread.start()
logger.info("[FeiShu] ✅ Websocket thread started, ready to receive messages")
ws_thread.join()
def _is_mention_bot(self, mentions: list) -> bool:
"""Check whether any mention in the list refers to this bot.
Priority:
1. Match by open_id (obtained from /bot/v3/info at startup, no config needed)
2. Fallback to feishu_bot_name config for backward compatibility
3. If neither is available, assume the first mention is the bot (Feishu only
delivers group messages that @-mention the bot, so this is usually correct)
"""
if self._bot_open_id:
return any(
m.get("id", {}).get("open_id") == self._bot_open_id
for m in mentions
)
bot_name = conf().get("feishu_bot_name")
if bot_name:
return any(m.get("name") == bot_name for m in mentions)
# Feishu event subscription only delivers messages that @-mention the bot,
# so reaching here means the bot was indeed mentioned.
return True
def _handle_message_event(self, event: dict):
"""
处理消息事件的核心逻辑
webhook和websocket模式共用此方法
"""
if not event.get("message") or not event.get("sender"):
logger.warning(f"[FeiShu] invalid message, event={event}")
return
msg = event.get("message")
# 幂等判断
msg_id = msg.get("message_id")
if self.receivedMsgs.get(msg_id):
logger.warning(f"[FeiShu] repeat msg filtered, msg_id={msg_id}")
return
self.receivedMsgs[msg_id] = True
# Filter out stale messages from before channel startup (offline backlog)
import time as _time
create_time_ms = msg.get("create_time")
if create_time_ms:
msg_age_s = _time.time() - int(create_time_ms) / 1000
if msg_age_s > 60:
logger.warning(f"[FeiShu] stale msg filtered (age={msg_age_s:.0f}s), msg_id={msg_id}")
return
is_group = False
chat_type = msg.get("chat_type")
if chat_type == "group":
if not msg.get("mentions") and msg.get("message_type") == "text":
# 群聊中未@不响应
return
if msg.get("mentions") and msg.get("message_type") == "text":
if not self._is_mention_bot(msg.get("mentions")):
return
# 群聊
is_group = True
receive_id_type = "chat_id"
elif chat_type == "p2p":
receive_id_type = "open_id"
else:
logger.warning("[FeiShu] message ignore")
return
# 构造飞书消息对象
feishu_msg = FeishuMessage(event, is_group=is_group, access_token=self.fetch_access_token())
if not feishu_msg:
return
# 处理文件缓存逻辑
from channel.file_cache import get_file_cache
file_cache = get_file_cache()
# 获取 session_id用于缓存关联
if is_group:
if conf().get("group_shared_session", True):
session_id = msg.get("chat_id") # 群共享会话
else:
session_id = feishu_msg.from_user_id + "_" + msg.get("chat_id")
else:
session_id = feishu_msg.from_user_id
# 如果是单张图片消息,缓存起来
if feishu_msg.ctype == ContextType.IMAGE:
if hasattr(feishu_msg, 'image_path') and feishu_msg.image_path:
file_cache.add(session_id, feishu_msg.image_path, file_type='image')
logger.info(f"[FeiShu] Image cached for session {session_id}, waiting for user query...")
# 单张图片不直接处理,等待用户提问
return
# 如果是文本消息,检查是否有缓存的文件
if feishu_msg.ctype == ContextType.TEXT:
cached_files = file_cache.get(session_id)
if cached_files:
# 将缓存的文件附加到文本消息中
file_refs = []
for file_info in cached_files:
file_path = file_info['path']
file_type = file_info['type']
if file_type == 'image':
file_refs.append(f"[图片: {file_path}]")
elif file_type == 'video':
file_refs.append(f"[视频: {file_path}]")
else:
file_refs.append(f"[文件: {file_path}]")
feishu_msg.content = feishu_msg.content + "\n" + "\n".join(file_refs)
logger.info(f"[FeiShu] Attached {len(cached_files)} cached file(s) to user query")
# 清除缓存
file_cache.clear(session_id)
context = self._compose_context(
feishu_msg.ctype,
feishu_msg.content,
isgroup=is_group,
msg=feishu_msg,
receive_id_type=receive_id_type,
no_need_at=True
)
if context:
self.produce(context)
logger.debug(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")
def send(self, reply: Reply, context: Context):
msg = context["msg"]
msg = context.get("msg")
is_group = context["isgroup"]
if msg:
access_token = msg.access_token
else:
access_token = self.fetch_access_token()
headers = {
"Authorization": "Bearer " + msg.access_token,
"Authorization": "Bearer " + access_token,
"Content-Type": "application/json",
}
msg_type = "text"
logger.info(f"[FeiShu] start send reply message, type={context.type}, content={reply.content}")
logger.debug(f"[FeiShu] sending reply, type={context.type}, content={reply.content[:100]}...")
reply_content = reply.content
content_key = "text"
if reply.type == ReplyType.IMAGE_URL:
# 图片上传
reply_content = self._upload_image_url(reply.content, msg.access_token)
reply_content = self._upload_image_url(reply.content, access_token)
if not reply_content:
logger.warning("[FeiShu] upload file failed")
logger.warning("[FeiShu] upload image failed")
return
msg_type = "image"
content_key = "image_key"
if is_group:
# 群聊中直接回复
elif reply.type == ReplyType.FILE:
# 如果有附加的文本内容,先发送文本
if hasattr(reply, 'text_content') and reply.text_content:
logger.info(f"[FeiShu] Sending text before file: {reply.text_content[:50]}...")
text_reply = Reply(ReplyType.TEXT, reply.text_content)
self._send(text_reply, context)
import time
time.sleep(0.3) # 短暂延迟,确保文本先到达
# 判断是否为视频文件
file_path = reply.content
if file_path.startswith("file://"):
file_path = file_path[7:]
is_video = file_path.lower().endswith(('.mp4', '.avi', '.mov', '.wmv', '.flv'))
if is_video:
# 视频上传包含duration信息
upload_data = self._upload_video_url(reply.content, access_token)
if not upload_data or not upload_data.get('file_key'):
logger.warning("[FeiShu] upload video failed")
return
# 视频使用 media 类型(根据官方文档)
# 错误码 230055 说明:上传 mp4 时必须使用 msg_type="media"
msg_type = "media"
reply_content = upload_data # 完整的上传响应数据包含file_key和duration
logger.info(
f"[FeiShu] Sending video: file_key={upload_data.get('file_key')}, duration={upload_data.get('duration')}ms")
content_key = None # 直接序列化整个对象
else:
# 其他文件使用 file 类型
file_key = self._upload_file_url(reply.content, access_token)
if not file_key:
logger.warning("[FeiShu] upload file failed")
return
reply_content = file_key
msg_type = "file"
content_key = "file_key"
# Check if we can reply to an existing message (need msg_id)
can_reply = is_group and msg and hasattr(msg, 'msg_id') and msg.msg_id
# Build content JSON
content_json = json.dumps(reply_content, ensure_ascii=False) if content_key is None else json.dumps({content_key: reply_content}, ensure_ascii=False)
logger.debug(f"[FeiShu] Sending message: msg_type={msg_type}, content={content_json[:200]}")
if can_reply:
# 群聊中回复已有消息
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.msg_id}/reply"
data = {
"msg_type": msg_type,
"content": json.dumps({content_key: reply_content})
"content": content_json
}
res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10))
else:
# 发送新消息私聊或群聊中无msg_id的情况如定时任务
url = "https://open.feishu.cn/open-apis/im/v1/messages"
params = {"receive_id_type": context.get("receive_id_type")}
params = {"receive_id_type": context.get("receive_id_type") or "open_id"}
data = {
"receive_id": context.get("receiver"),
"msg_type": msg_type,
"content": json.dumps({content_key: reply_content})
"content": content_json
}
res = requests.post(url=url, headers=headers, params=params, json=data, timeout=(5, 10))
res = res.json()
@@ -92,7 +481,6 @@ class FeiShuChanel(ChatChannel):
else:
logger.error(f"[FeiShu] send message failed, code={res.get('code')}, msg={res.get('msg')}")
def fetch_access_token(self) -> str:
url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal/"
headers = {
@@ -114,9 +502,35 @@ class FeiShuChanel(ChatChannel):
else:
logger.error(f"[FeiShu] fetch token error, res={response}")
def _upload_image_url(self, img_url, access_token):
logger.debug(f"[WX] start download image, img_url={img_url}")
logger.debug(f"[FeiShu] start process image, img_url={img_url}")
# Check if it's a local file path (file:// protocol)
if img_url.startswith("file://"):
local_path = img_url[7:] # Remove "file://" prefix
logger.info(f"[FeiShu] uploading local file: {local_path}")
if not os.path.exists(local_path):
logger.error(f"[FeiShu] local file not found: {local_path}")
return None
# Upload directly from local file
upload_url = "https://open.feishu.cn/open-apis/im/v1/images"
data = {'image_type': 'message'}
headers = {'Authorization': f'Bearer {access_token}'}
with open(local_path, "rb") as file:
upload_response = requests.post(upload_url, files={"image": file}, data=data, headers=headers)
logger.info(f"[FeiShu] upload file, res={upload_response.content}")
response_data = upload_response.json()
if response_data.get("code") == 0:
return response_data.get("data").get("image_key")
else:
logger.error(f"[FeiShu] upload failed: {response_data}")
return None
# Original logic for HTTP URLs
response = requests.get(img_url)
suffix = utils.get_path_suffix(img_url)
temp_name = str(uuid.uuid4()) + "." + suffix
@@ -139,9 +553,295 @@ class FeiShuChanel(ChatChannel):
os.remove(temp_name)
return upload_response.json().get("data").get("image_key")
def _get_video_duration(self, file_path: str) -> int:
"""
获取视频时长(毫秒)
Args:
file_path: 视频文件路径
Returns:
视频时长毫秒如果获取失败返回0
"""
try:
import subprocess
# 使用 ffprobe 获取视频时长
cmd = [
'ffprobe',
'-v', 'error',
'-show_entries', 'format=duration',
'-of', 'default=noprint_wrappers=1:nokey=1',
file_path
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
if result.returncode == 0:
duration_seconds = float(result.stdout.strip())
duration_ms = int(duration_seconds * 1000)
logger.info(f"[FeiShu] Video duration: {duration_seconds:.2f}s ({duration_ms}ms)")
return duration_ms
else:
logger.warning(f"[FeiShu] Failed to get video duration via ffprobe: {result.stderr}")
return 0
except FileNotFoundError:
logger.warning("[FeiShu] ffprobe not found, video duration will be 0. Install ffmpeg to fix this.")
return 0
except Exception as e:
logger.warning(f"[FeiShu] Failed to get video duration: {e}")
return 0
def _upload_video_url(self, video_url, access_token):
"""
Upload video to Feishu and return video info (file_key and duration)
Supports:
- file:// URLs for local files
- http(s):// URLs (download then upload)
Returns:
dict with 'file_key' and 'duration' (milliseconds), or None if failed
"""
local_path = None
temp_file = None
try:
# For file:// URLs (local files), upload directly
if video_url.startswith("file://"):
local_path = video_url[7:] # Remove file:// prefix
if not os.path.exists(local_path):
logger.error(f"[FeiShu] local video file not found: {local_path}")
return None
else:
# For HTTP URLs, download first
logger.info(f"[FeiShu] Downloading video from URL: {video_url}")
response = requests.get(video_url, timeout=(5, 60))
if response.status_code != 200:
logger.error(f"[FeiShu] download video failed, status={response.status_code}")
return None
# Save to temp file
import uuid
file_name = os.path.basename(video_url) or "video.mp4"
temp_file = str(uuid.uuid4()) + "_" + file_name
with open(temp_file, "wb") as file:
file.write(response.content)
logger.info(f"[FeiShu] Video downloaded, size={len(response.content)} bytes")
local_path = temp_file
# Get video duration
duration = self._get_video_duration(local_path)
# Upload to Feishu
file_name = os.path.basename(local_path)
file_ext = os.path.splitext(file_name)[1].lower()
file_type_map = {'.mp4': 'mp4'}
file_type = file_type_map.get(file_ext, 'mp4')
upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
data = {
'file_type': file_type,
'file_name': file_name
}
# Add duration only if available (required for video/audio)
if duration:
data['duration'] = duration # Must be int, not string
headers = {'Authorization': f'Bearer {access_token}'}
logger.info(f"[FeiShu] Uploading video: file_name={file_name}, duration={duration}ms")
with open(local_path, "rb") as file:
upload_response = requests.post(
upload_url,
files={"file": file},
data=data,
headers=headers,
timeout=(5, 60)
)
logger.info(
f"[FeiShu] upload video response, status={upload_response.status_code}, res={upload_response.content}")
response_data = upload_response.json()
if response_data.get("code") == 0:
# Add duration to the response data (API doesn't return it)
upload_data = response_data.get("data")
upload_data['duration'] = duration # Add our calculated duration
logger.info(
f"[FeiShu] Upload complete: file_key={upload_data.get('file_key')}, duration={duration}ms")
return upload_data
else:
logger.error(f"[FeiShu] upload video failed: {response_data}")
return None
except Exception as e:
logger.error(f"[FeiShu] upload video exception: {e}")
return None
finally:
# Clean up temp file
if temp_file and os.path.exists(temp_file):
try:
os.remove(temp_file)
except Exception as e:
logger.warning(f"[FeiShu] Failed to remove temp file {temp_file}: {e}")
def _upload_file_url(self, file_url, access_token):
"""
Upload file to Feishu
Supports both local files (file://) and HTTP URLs
"""
logger.debug(f"[FeiShu] start process file, file_url={file_url}")
# Check if it's a local file path (file:// protocol)
if file_url.startswith("file://"):
local_path = file_url[7:] # Remove "file://" prefix
logger.info(f"[FeiShu] uploading local file: {local_path}")
if not os.path.exists(local_path):
logger.error(f"[FeiShu] local file not found: {local_path}")
return None
# Get file info
file_name = os.path.basename(local_path)
file_ext = os.path.splitext(file_name)[1].lower()
# Determine file type for Feishu API
# Feishu supports: opus, mp4, pdf, doc, xls, ppt, stream (other types)
file_type_map = {
'.opus': 'opus',
'.mp4': 'mp4',
'.pdf': 'pdf',
'.doc': 'doc', '.docx': 'doc',
'.xls': 'xls', '.xlsx': 'xls',
'.ppt': 'ppt', '.pptx': 'ppt',
}
file_type = file_type_map.get(file_ext, 'stream') # Default to stream for other types
# Upload file to Feishu
upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
data = {'file_type': file_type, 'file_name': file_name}
headers = {'Authorization': f'Bearer {access_token}'}
try:
with open(local_path, "rb") as file:
upload_response = requests.post(
upload_url,
files={"file": file},
data=data,
headers=headers,
timeout=(5, 30) # 5s connect, 30s read timeout
)
logger.info(
f"[FeiShu] upload file response, status={upload_response.status_code}, res={upload_response.content}")
response_data = upload_response.json()
if response_data.get("code") == 0:
return response_data.get("data").get("file_key")
else:
logger.error(f"[FeiShu] upload file failed: {response_data}")
return None
except Exception as e:
logger.error(f"[FeiShu] upload file exception: {e}")
return None
# For HTTP URLs, download first then upload
try:
response = requests.get(file_url, timeout=(5, 30))
if response.status_code != 200:
logger.error(f"[FeiShu] download file failed, status={response.status_code}")
return None
# Save to temp file
import uuid
file_name = os.path.basename(file_url)
temp_name = str(uuid.uuid4()) + "_" + file_name
with open(temp_name, "wb") as file:
file.write(response.content)
# Upload
file_ext = os.path.splitext(file_name)[1].lower()
file_type_map = {
'.opus': 'opus', '.mp4': 'mp4', '.pdf': 'pdf',
'.doc': 'doc', '.docx': 'doc',
'.xls': 'xls', '.xlsx': 'xls',
'.ppt': 'ppt', '.pptx': 'ppt',
}
file_type = file_type_map.get(file_ext, 'stream')
upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
data = {'file_type': file_type, 'file_name': file_name}
headers = {'Authorization': f'Bearer {access_token}'}
with open(temp_name, "rb") as file:
upload_response = requests.post(upload_url, files={"file": file}, data=data, headers=headers)
logger.info(f"[FeiShu] upload file, res={upload_response.content}")
response_data = upload_response.json()
os.remove(temp_name) # Clean up temp file
if response_data.get("code") == 0:
return response_data.get("data").get("file_key")
else:
logger.error(f"[FeiShu] upload file failed: {response_data}")
return None
except Exception as e:
logger.error(f"[FeiShu] upload file from URL exception: {e}")
return None
def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content)
context.kwargs = kwargs
if "channel_type" not in context:
context["channel_type"] = self.channel_type
if "origin_ctype" not in context:
context["origin_ctype"] = ctype
cmsg = context["msg"]
# Set session_id based on chat type
if cmsg.is_group:
# Group chat: check if group_shared_session is enabled
if conf().get("group_shared_session", True):
# All users in the group share the same session context
context["session_id"] = cmsg.other_user_id # group_id
else:
# Each user has their own session within the group
# This ensures:
# - Same user in different groups have separate conversation histories
# - Same user in private chat and group chat have separate histories
context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}"
else:
# Private chat: use user_id only
context["session_id"] = cmsg.from_user_id
context["receiver"] = cmsg.other_user_id
if ctype == ContextType.TEXT:
# 1.文本请求
# 图片生成处理
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
if img_match_prefix:
content = content.replace(img_match_prefix, "", 1)
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT
context.content = content.strip()
elif context.type == ContextType.VOICE:
# 2.语音请求
if "desire_rtype" not in context and conf().get("voice_reply_voice"):
context["desire_rtype"] = ReplyType.VOICE
return context
class FeishuController:
"""
HTTP服务器控制器用于webhook模式
"""
# 类常量
FAILED_MSG = '{"success": false}'
SUCCESS_MSG = '{"success": true}'
@@ -171,80 +871,10 @@ class FeishuController:
# 处理消息事件
event = request.get("event")
if header.get("event_type") == self.MESSAGE_RECEIVE_TYPE and event:
if not event.get("message") or not event.get("sender"):
logger.warning(f"[FeiShu] invalid message, msg={request}")
return self.FAILED_MSG
msg = event.get("message")
channel._handle_message_event(event)
# 幂等判断
if channel.receivedMsgs.get(msg.get("message_id")):
logger.warning(f"[FeiShu] repeat msg filtered, event_id={header.get('event_id')}")
return self.SUCCESS_MSG
channel.receivedMsgs[msg.get("message_id")] = True
is_group = False
chat_type = msg.get("chat_type")
if chat_type == "group":
if not msg.get("mentions") and msg.get("message_type") == "text":
# 群聊中未@不响应
return self.SUCCESS_MSG
if msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name") and msg.get("message_type") == "text":
# 不是@机器人,不响应
return self.SUCCESS_MSG
# 群聊
is_group = True
receive_id_type = "chat_id"
elif chat_type == "p2p":
receive_id_type = "open_id"
else:
logger.warning("[FeiShu] message ignore")
return self.SUCCESS_MSG
# 构造飞书消息对象
feishu_msg = FeishuMessage(event, is_group=is_group, access_token=channel.fetch_access_token())
if not feishu_msg:
return self.SUCCESS_MSG
context = self._compose_context(
feishu_msg.ctype,
feishu_msg.content,
isgroup=is_group,
msg=feishu_msg,
receive_id_type=receive_id_type,
no_need_at=True
)
if context:
channel.produce(context)
logger.info(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")
return self.SUCCESS_MSG
except Exception as e:
logger.error(e)
return self.FAILED_MSG
def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content)
context.kwargs = kwargs
if "origin_ctype" not in context:
context["origin_ctype"] = ctype
cmsg = context["msg"]
context["session_id"] = cmsg.from_user_id
context["receiver"] = cmsg.other_user_id
if ctype == ContextType.TEXT:
# 1.文本请求
# 图片生成处理
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
if img_match_prefix:
content = content.replace(img_match_prefix, "", 1)
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT
context.content = content.strip()
elif context.type == ContextType.VOICE:
# 2.语音请求
if "desire_rtype" not in context and conf().get("voice_reply_voice"):
context["desire_rtype"] = ReplyType.VOICE
return context

View File

@@ -1,10 +1,13 @@
from bridge.context import ContextType
from channel.chat_message import ChatMessage
import json
import os
import requests
from common.log import logger
from common.tmp_dir import TmpDir
from common import utils
from common.utils import expand_path
from config import conf
class FeishuMessage(ChatMessage):
@@ -22,6 +25,119 @@ class FeishuMessage(ChatMessage):
self.ctype = ContextType.TEXT
content = json.loads(msg.get('content'))
self.content = content.get("text").strip()
elif msg_type == "image":
# 单张图片消息:下载并缓存,等待用户提问时一起发送
self.ctype = ContextType.IMAGE
content = json.loads(msg.get("content"))
image_key = content.get("image_key")
# 下载图片到工作空间临时目录
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(workspace_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
image_path = os.path.join(tmp_dir, f"{image_key}.png")
# 下载图片
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.get('message_id')}/resources/{image_key}"
headers = {"Authorization": "Bearer " + access_token}
params = {"type": "image"}
response = requests.get(url=url, headers=headers, params=params)
if response.status_code == 200:
with open(image_path, "wb") as f:
f.write(response.content)
logger.info(f"[FeiShu] Downloaded single image, key={image_key}, path={image_path}")
self.content = image_path
self.image_path = image_path # 保存图片路径
else:
logger.error(f"[FeiShu] Failed to download single image, key={image_key}, status={response.status_code}")
self.content = f"[图片下载失败: {image_key}]"
self.image_path = None
elif msg_type == "post":
# 富文本消息,可能包含图片、文本等多种元素
content = json.loads(msg.get("content"))
# 飞书富文本消息结构content 直接包含 title 和 content 数组
# 不是嵌套在 post 字段下
title = content.get("title", "")
content_list = content.get("content", [])
logger.info(f"[FeiShu] Post message - title: '{title}', content_list length: {len(content_list)}")
# 收集所有图片和文本
image_keys = []
text_parts = []
if title:
text_parts.append(title)
for block in content_list:
logger.debug(f"[FeiShu] Processing block: {block}")
# block 本身就是元素列表
if not isinstance(block, list):
continue
for element in block:
element_tag = element.get("tag")
logger.debug(f"[FeiShu] Element tag: {element_tag}, element: {element}")
if element_tag == "img":
# 找到图片元素
image_key = element.get("image_key")
if image_key:
image_keys.append(image_key)
elif element_tag == "text":
# 文本元素
text_content = element.get("text", "")
if text_content:
text_parts.append(text_content)
logger.info(f"[FeiShu] Parsed - images: {len(image_keys)}, text_parts: {text_parts}")
# 富文本消息统一作为文本消息处理
self.ctype = ContextType.TEXT
if image_keys:
# 如果包含图片,下载并在文本中引用本地路径
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(workspace_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
# 保存图片路径映射
self.image_paths = {}
for image_key in image_keys:
image_path = os.path.join(tmp_dir, f"{image_key}.png")
self.image_paths[image_key] = image_path
def _download_images():
for image_key, image_path in self.image_paths.items():
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{image_key}"
headers = {"Authorization": "Bearer " + access_token}
params = {"type": "image"}
response = requests.get(url=url, headers=headers, params=params)
if response.status_code == 200:
with open(image_path, "wb") as f:
f.write(response.content)
logger.info(f"[FeiShu] Image downloaded from post message, key={image_key}, path={image_path}")
else:
logger.error(f"[FeiShu] Failed to download image from post, key={image_key}, status={response.status_code}")
# 立即下载图片,不使用延迟下载
# 因为 TEXT 类型消息不会调用 prepare()
_download_images()
# 构建消息内容:文本 + 图片路径
content_parts = []
if text_parts:
content_parts.append("\n".join(text_parts).strip())
for image_key, image_path in self.image_paths.items():
content_parts.append(f"[图片: {image_path}]")
self.content = "\n".join(content_parts)
logger.info(f"[FeiShu] Received post message with {len(image_keys)} image(s) and text: {self.content}")
else:
# 纯文本富文本消息
self.content = "\n".join(text_parts).strip() if text_parts else "[富文本消息]"
logger.info(f"[FeiShu] Received post message (text only): {self.content}")
elif msg_type == "file":
self.ctype = ContextType.FILE
content = json.loads(msg.get("content"))
@@ -46,35 +162,6 @@ class FeishuMessage(ChatMessage):
else:
logger.info(f"[FeiShu] Failed to download file, key={file_key}, res={response.text}")
self._prepare_fn = _download_file
# elif msg.type == "voice":
# self.ctype = ContextType.VOICE
# self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
#
# def download_voice():
# # 如果响应状态码是200则将响应内容写入本地文件
# response = client.media.download(msg.media_id)
# if response.status_code == 200:
# with open(self.content, "wb") as f:
# f.write(response.content)
# else:
# logger.info(f"[wechatcom] Failed to download voice file, {response.content}")
#
# self._prepare_fn = download_voice
# elif msg.type == "image":
# self.ctype = ContextType.IMAGE
# self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
#
# def download_image():
# # 如果响应状态码是200则将响应内容写入本地文件
# response = client.media.download(msg.media_id)
# if response.status_code == 200:
# with open(self.content, "wb") as f:
# f.write(response.content)
# else:
# logger.info(f"[wechatcom] Failed to download image file, {response.content}")
#
# self._prepare_fn = download_image
else:
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg_type))

100
channel/file_cache.py Normal file
View File

@@ -0,0 +1,100 @@
"""
文件缓存管理器
用于缓存单独发送的文件消息(图片、视频、文档等),在用户提问时自动附加
"""
import time
import logging
logger = logging.getLogger(__name__)
class FileCache:
"""文件缓存管理器,按 session_id 缓存文件TTL=2分钟"""
def __init__(self, ttl=120):
"""
Args:
ttl: 缓存过期时间默认2分钟
"""
self.cache = {}
self.ttl = ttl
def add(self, session_id: str, file_path: str, file_type: str = "image"):
"""
添加文件到缓存
Args:
session_id: 会话ID
file_path: 文件本地路径
file_type: 文件类型image, video, file 等)
"""
if session_id not in self.cache:
self.cache[session_id] = {
'files': [],
'timestamp': time.time()
}
# 添加文件(去重)
file_info = {'path': file_path, 'type': file_type}
if file_info not in self.cache[session_id]['files']:
self.cache[session_id]['files'].append(file_info)
logger.info(f"[FileCache] Added {file_type} to cache for session {session_id}: {file_path}")
def get(self, session_id: str) -> list:
"""
获取缓存的文件列表
Args:
session_id: 会话ID
Returns:
文件信息列表 [{'path': '...', 'type': 'image'}, ...],如果没有或已过期返回空列表
"""
if session_id not in self.cache:
return []
item = self.cache[session_id]
# 检查是否过期
if time.time() - item['timestamp'] > self.ttl:
logger.info(f"[FileCache] Cache expired for session {session_id}, clearing...")
del self.cache[session_id]
return []
return item['files']
def clear(self, session_id: str):
"""
清除指定会话的缓存
Args:
session_id: 会话ID
"""
if session_id in self.cache:
logger.info(f"[FileCache] Cleared cache for session {session_id}")
del self.cache[session_id]
def cleanup_expired(self):
"""清理所有过期的缓存"""
current_time = time.time()
expired_sessions = []
for session_id, item in self.cache.items():
if current_time - item['timestamp'] > self.ttl:
expired_sessions.append(session_id)
for session_id in expired_sessions:
del self.cache[session_id]
logger.debug(f"[FileCache] Cleaned up expired cache for session {session_id}")
if expired_sessions:
logger.info(f"[FileCache] Cleaned up {len(expired_sessions)} expired cache(s)")
# 全局单例
_file_cache = FileCache()
def get_file_cache() -> FileCache:
"""获取全局文件缓存实例"""
return _file_cache

0
channel/qq/__init__.py Normal file
View File

736
channel/qq/qq_channel.py Normal file
View File

@@ -0,0 +1,736 @@
"""
QQ Bot channel via WebSocket long connection.
Supports:
- Group chat (@bot), single chat (C2C), guild channel, guild DM
- Text / image / file message send & receive
- Heartbeat keep-alive and auto-reconnect with session resume
"""
import base64
import json
import os
import threading
import time
import requests
import websocket
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel, check_prefix
from channel.qq.qq_message import QQMessage
from common.expired_dict import ExpiredDict
from common.log import logger
from common.singleton import singleton
from common.ws_client_compat import websocket_app_run_forever
from config import conf
# Rich media file_type constants
QQ_FILE_TYPE_IMAGE = 1
QQ_FILE_TYPE_VIDEO = 2
QQ_FILE_TYPE_VOICE = 3
QQ_FILE_TYPE_FILE = 4
QQ_API_BASE = "https://api.sgroup.qq.com"
# Intents: GROUP_AND_C2C_EVENT(1<<25) | PUBLIC_GUILD_MESSAGES(1<<30)
DEFAULT_INTENTS = (1 << 25) | (1 << 30)
# OpCode constants
OP_DISPATCH = 0
OP_HEARTBEAT = 1
OP_IDENTIFY = 2
OP_RESUME = 6
OP_RECONNECT = 7
OP_INVALID_SESSION = 9
OP_HELLO = 10
OP_HEARTBEAT_ACK = 11
# Resumable error codes
RESUMABLE_CLOSE_CODES = {4008, 4009}
@singleton
class QQChannel(ChatChannel):
def __init__(self):
super().__init__()
self.app_id = ""
self.app_secret = ""
self._access_token = ""
self._token_expires_at = 0
self._ws = None
self._ws_thread = None
self._heartbeat_thread = None
self._connected = False
self._stop_event = threading.Event()
self._token_lock = threading.Lock()
self._session_id = None
self._last_seq = None
self._heartbeat_interval = 45000
self._can_resume = False
self.received_msgs = ExpiredDict(60 * 60 * 7.1)
self._msg_seq_counter = {}
conf()["group_name_white_list"] = ["ALL_GROUP"]
conf()["single_chat_prefix"] = [""]
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def startup(self):
self.app_id = conf().get("qq_app_id", "")
self.app_secret = conf().get("qq_app_secret", "")
if not self.app_id or not self.app_secret:
err = "[QQ] qq_app_id and qq_app_secret are required"
logger.error(err)
self.report_startup_error(err)
return
self._refresh_access_token()
if not self._access_token:
err = "[QQ] Failed to get initial access_token"
logger.error(err)
self.report_startup_error(err)
return
self._stop_event.clear()
self._start_ws()
def stop(self):
logger.info("[QQ] stop() called")
self._stop_event.set()
if self._ws:
try:
self._ws.close()
except Exception:
pass
self._ws = None
self._connected = False
# ------------------------------------------------------------------
# Access Token
# ------------------------------------------------------------------
def _refresh_access_token(self):
try:
resp = requests.post(
"https://bots.qq.com/app/getAppAccessToken",
json={"appId": self.app_id, "clientSecret": self.app_secret},
timeout=10,
)
resp.raise_for_status()
data = resp.json()
self._access_token = data.get("access_token", "")
expires_in = int(data.get("expires_in", 7200))
self._token_expires_at = time.time() + expires_in - 60
logger.debug(f"[QQ] Access token refreshed, expires_in={expires_in}s")
except Exception as e:
logger.error(f"[QQ] Failed to refresh access_token: {e}")
def _get_access_token(self) -> str:
with self._token_lock:
if time.time() >= self._token_expires_at:
self._refresh_access_token()
return self._access_token
def _get_auth_headers(self) -> dict:
return {
"Authorization": f"QQBot {self._get_access_token()}",
"Content-Type": "application/json",
}
# ------------------------------------------------------------------
# WebSocket connection
# ------------------------------------------------------------------
def _get_ws_url(self) -> str:
try:
resp = requests.get(
f"{QQ_API_BASE}/gateway",
headers=self._get_auth_headers(),
timeout=10,
)
resp.raise_for_status()
url = resp.json().get("url", "")
logger.debug(f"[QQ] Gateway URL: {url}")
return url
except Exception as e:
logger.error(f"[QQ] Failed to get gateway URL: {e}")
return ""
def _start_ws(self):
ws_url = self._get_ws_url()
if not ws_url:
logger.error("[QQ] Cannot start WebSocket without gateway URL")
self.report_startup_error("Failed to get gateway URL")
return
def _on_open(ws):
logger.debug("[QQ] WebSocket connected, waiting for Hello...")
def _on_message(ws, raw):
try:
data = json.loads(raw)
self._handle_ws_message(data)
except Exception as e:
logger.error(f"[QQ] Failed to handle ws message: {e}", exc_info=True)
def _on_error(ws, error):
logger.error(f"[QQ] WebSocket error: {error}")
def _on_close(ws, close_status_code, close_msg):
logger.warning(f"[QQ] WebSocket closed: status={close_status_code}, msg={close_msg}")
self._connected = False
if not self._stop_event.is_set():
if close_status_code in RESUMABLE_CLOSE_CODES and self._session_id:
self._can_resume = True
logger.info("[QQ] Will attempt resume in 3s...")
time.sleep(3)
else:
self._can_resume = False
logger.info("[QQ] Will reconnect in 5s...")
time.sleep(5)
if not self._stop_event.is_set():
self._start_ws()
self._ws = websocket.WebSocketApp(
ws_url,
on_open=_on_open,
on_message=_on_message,
on_error=_on_error,
on_close=_on_close,
)
def run_forever():
try:
websocket_app_run_forever(self._ws, ping_interval=0, reconnect=0)
except (SystemExit, KeyboardInterrupt):
logger.info("[QQ] WebSocket thread interrupted")
except Exception as e:
logger.error(f"[QQ] WebSocket run_forever error: {e}")
self._ws_thread = threading.Thread(target=run_forever, daemon=True)
self._ws_thread.start()
self._ws_thread.join()
def _ws_send(self, data: dict):
if self._ws:
self._ws.send(json.dumps(data, ensure_ascii=False))
# ------------------------------------------------------------------
# Identify & Resume & Heartbeat
# ------------------------------------------------------------------
def _send_identify(self):
self._ws_send({
"op": OP_IDENTIFY,
"d": {
"token": f"QQBot {self._get_access_token()}",
"intents": DEFAULT_INTENTS,
"shard": [0, 1],
"properties": {
"$os": "linux",
"$browser": "chatgpt-on-wechat",
"$device": "chatgpt-on-wechat",
},
},
})
logger.debug(f"[QQ] Identify sent with intents={DEFAULT_INTENTS}")
def _send_resume(self):
self._ws_send({
"op": OP_RESUME,
"d": {
"token": f"QQBot {self._get_access_token()}",
"session_id": self._session_id,
"seq": self._last_seq,
},
})
logger.debug(f"[QQ] Resume sent: session_id={self._session_id}, seq={self._last_seq}")
def _start_heartbeat(self, interval_ms: int):
if self._heartbeat_thread and self._heartbeat_thread.is_alive():
return
self._heartbeat_interval = interval_ms
interval_sec = interval_ms / 1000.0
def heartbeat_loop():
while not self._stop_event.is_set() and self._connected:
try:
self._ws_send({
"op": OP_HEARTBEAT,
"d": self._last_seq,
})
except Exception as e:
logger.warning(f"[QQ] Heartbeat send failed: {e}")
break
self._stop_event.wait(interval_sec)
self._heartbeat_thread = threading.Thread(target=heartbeat_loop, daemon=True)
self._heartbeat_thread.start()
# ------------------------------------------------------------------
# Incoming message dispatch
# ------------------------------------------------------------------
def _handle_ws_message(self, data: dict):
op = data.get("op")
d = data.get("d")
t = data.get("t")
s = data.get("s")
if s is not None:
self._last_seq = s
if op == OP_HELLO:
heartbeat_interval = d.get("heartbeat_interval", 45000) if d else 45000
logger.debug(f"[QQ] Received Hello, heartbeat_interval={heartbeat_interval}ms")
self._heartbeat_interval = heartbeat_interval
if self._can_resume and self._session_id:
self._send_resume()
else:
self._send_identify()
elif op == OP_HEARTBEAT_ACK:
pass
elif op == OP_HEARTBEAT:
self._ws_send({"op": OP_HEARTBEAT, "d": self._last_seq})
elif op == OP_RECONNECT:
logger.warning("[QQ] Server requested reconnect")
self._can_resume = True
if self._ws:
self._ws.close()
elif op == OP_INVALID_SESSION:
logger.warning("[QQ] Invalid session, re-identifying...")
self._session_id = None
self._can_resume = False
time.sleep(2)
self._send_identify()
elif op == OP_DISPATCH:
if t == "READY":
self._session_id = d.get("session_id", "")
user = d.get("user", {})
bot_name = user.get('username', '')
logger.info(f"[QQ] ✅ Connected successfully (bot={bot_name})")
self._connected = True
self._can_resume = False
self._start_heartbeat(self._heartbeat_interval)
self.report_startup_success()
elif t == "RESUMED":
logger.info("[QQ] Session resumed successfully")
self._connected = True
self._can_resume = False
self._start_heartbeat(self._heartbeat_interval)
elif t in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE",
"AT_MESSAGE_CREATE", "DIRECT_MESSAGE_CREATE"):
self._handle_msg_event(d, t)
elif t in ("GROUP_ADD_ROBOT", "FRIEND_ADD"):
logger.info(f"[QQ] Event: {t}")
else:
logger.debug(f"[QQ] Dispatch event: {t}")
# ------------------------------------------------------------------
# Message event handling
# ------------------------------------------------------------------
def _handle_msg_event(self, event_data: dict, event_type: str):
msg_id = event_data.get("id", "")
if self.received_msgs.get(msg_id):
logger.debug(f"[QQ] Duplicate msg filtered: {msg_id}")
return
self.received_msgs[msg_id] = True
try:
qq_msg = QQMessage(event_data, event_type)
except NotImplementedError as e:
logger.warning(f"[QQ] {e}")
return
except Exception as e:
logger.error(f"[QQ] Failed to parse message: {e}", exc_info=True)
return
is_group = qq_msg.is_group
from channel.file_cache import get_file_cache
file_cache = get_file_cache()
if is_group:
session_id = qq_msg.other_user_id
else:
session_id = qq_msg.from_user_id
if qq_msg.ctype == ContextType.IMAGE:
if hasattr(qq_msg, "image_path") and qq_msg.image_path:
file_cache.add(session_id, qq_msg.image_path, file_type="image")
logger.info(f"[QQ] Image cached for session {session_id}")
return
if qq_msg.ctype == ContextType.TEXT:
cached_files = file_cache.get(session_id)
if cached_files:
file_refs = []
for fi in cached_files:
ftype = fi["type"]
fpath = fi["path"]
if ftype == "image":
file_refs.append(f"[图片: {fpath}]")
elif ftype == "video":
file_refs.append(f"[视频: {fpath}]")
else:
file_refs.append(f"[文件: {fpath}]")
qq_msg.content = qq_msg.content + "\n" + "\n".join(file_refs)
logger.info(f"[QQ] Attached {len(cached_files)} cached file(s)")
file_cache.clear(session_id)
context = self._compose_context(
qq_msg.ctype,
qq_msg.content,
isgroup=is_group,
msg=qq_msg,
no_need_at=True,
)
if context:
self.produce(context)
# ------------------------------------------------------------------
# _compose_context
# ------------------------------------------------------------------
def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content)
context.kwargs = kwargs
if "channel_type" not in context:
context["channel_type"] = self.channel_type
if "origin_ctype" not in context:
context["origin_ctype"] = ctype
cmsg = context["msg"]
if cmsg.is_group:
context["session_id"] = cmsg.other_user_id
else:
context["session_id"] = cmsg.from_user_id
context["receiver"] = cmsg.other_user_id
if ctype == ContextType.TEXT:
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
if img_match_prefix:
content = content.replace(img_match_prefix, "", 1)
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT
context.content = content.strip()
return context
# ------------------------------------------------------------------
# Send reply
# ------------------------------------------------------------------
def send(self, reply: Reply, context: Context):
msg = context.get("msg")
is_group = context.get("isgroup", False)
receiver = context.get("receiver", "")
if not msg:
# Active send (e.g. scheduled tasks), no original message to reply to
self._active_send_text(reply.content if reply.type == ReplyType.TEXT else str(reply.content),
receiver, is_group)
return
event_type = getattr(msg, "event_type", "")
msg_id = getattr(msg, "msg_id", "")
if reply.type == ReplyType.TEXT:
self._send_text(reply.content, msg, event_type, msg_id)
elif reply.type in (ReplyType.IMAGE_URL, ReplyType.IMAGE):
self._send_image(reply.content, msg, event_type, msg_id)
elif reply.type == ReplyType.FILE:
if hasattr(reply, "text_content") and reply.text_content:
self._send_text(reply.text_content, msg, event_type, msg_id)
time.sleep(0.3)
self._send_file(reply.content, msg, event_type, msg_id)
elif reply.type in (ReplyType.VIDEO, ReplyType.VIDEO_URL):
self._send_media(reply.content, msg, event_type, msg_id, QQ_FILE_TYPE_VIDEO)
else:
logger.warning(f"[QQ] Unsupported reply type: {reply.type}, falling back to text")
self._send_text(str(reply.content), msg, event_type, msg_id)
# ------------------------------------------------------------------
# Send helpers
# ------------------------------------------------------------------
def _get_next_msg_seq(self, msg_id: str) -> int:
seq = self._msg_seq_counter.get(msg_id, 1)
self._msg_seq_counter[msg_id] = seq + 1
return seq
def _build_msg_url_and_base_body(self, msg: QQMessage, event_type: str, msg_id: str):
"""Build the API URL and base body dict for sending a message."""
if event_type == "GROUP_AT_MESSAGE_CREATE":
group_openid = msg._rawmsg.get("group_openid", "")
url = f"{QQ_API_BASE}/v2/groups/{group_openid}/messages"
body = {
"msg_id": msg_id,
"msg_seq": self._get_next_msg_seq(msg_id),
}
return url, body, "group", group_openid
elif event_type == "C2C_MESSAGE_CREATE":
user_openid = msg._rawmsg.get("author", {}).get("user_openid", "") or msg.from_user_id
url = f"{QQ_API_BASE}/v2/users/{user_openid}/messages"
body = {
"msg_id": msg_id,
"msg_seq": self._get_next_msg_seq(msg_id),
}
return url, body, "c2c", user_openid
elif event_type == "AT_MESSAGE_CREATE":
channel_id = msg._rawmsg.get("channel_id", "")
url = f"{QQ_API_BASE}/channels/{channel_id}/messages"
body = {"msg_id": msg_id}
return url, body, "channel", channel_id
elif event_type == "DIRECT_MESSAGE_CREATE":
guild_id = msg._rawmsg.get("guild_id", "")
url = f"{QQ_API_BASE}/dms/{guild_id}/messages"
body = {"msg_id": msg_id}
return url, body, "dm", guild_id
return None, None, None, None
def _post_message(self, url: str, body: dict, event_type: str):
try:
resp = requests.post(url, json=body, headers=self._get_auth_headers(), timeout=10)
if resp.status_code in (200, 201, 202, 204):
logger.info(f"[QQ] Message sent successfully: event_type={event_type}")
else:
logger.error(f"[QQ] Failed to send message: status={resp.status_code}, "
f"body={resp.text}")
except Exception as e:
logger.error(f"[QQ] Send message error: {e}")
# ------------------------------------------------------------------
# Active send (no original message, e.g. scheduled tasks)
# ------------------------------------------------------------------
def _active_send_text(self, content: str, receiver: str, is_group: bool):
"""Send text without an original message (active push). QQ limits active messages to 4/month per user."""
if not receiver:
logger.warning("[QQ] No receiver for active send")
return
if is_group:
url = f"{QQ_API_BASE}/v2/groups/{receiver}/messages"
else:
url = f"{QQ_API_BASE}/v2/users/{receiver}/messages"
body = {
"content": content,
"msg_type": 0,
}
event_label = "GROUP_ACTIVE" if is_group else "C2C_ACTIVE"
self._post_message(url, body, event_label)
# ------------------------------------------------------------------
# Send text
# ------------------------------------------------------------------
def _send_text(self, content: str, msg: QQMessage, event_type: str, msg_id: str):
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
if not url:
logger.warning(f"[QQ] Cannot send reply for event_type: {event_type}")
return
body["content"] = content
body["msg_type"] = 0
self._post_message(url, body, event_type)
# ------------------------------------------------------------------
# Rich media upload & send (image / video / file)
# ------------------------------------------------------------------
def _upload_rich_media(self, file_url: str, file_type: int, msg: QQMessage,
event_type: str) -> str:
"""
Upload media via QQ rich media API and return file_info.
For group: POST /v2/groups/{group_openid}/files
For c2c: POST /v2/users/{openid}/files
"""
if event_type == "GROUP_AT_MESSAGE_CREATE":
group_openid = msg._rawmsg.get("group_openid", "")
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
elif event_type == "C2C_MESSAGE_CREATE":
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
or msg.from_user_id)
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
else:
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
return ""
upload_body = {
"file_type": file_type,
"url": file_url,
"srv_send_msg": False,
}
try:
resp = requests.post(
upload_url, json=upload_body,
headers=self._get_auth_headers(), timeout=30,
)
if resp.status_code in (200, 201):
data = resp.json()
file_info = data.get("file_info", "")
logger.info(f"[QQ] Rich media uploaded: file_type={file_type}, "
f"file_uuid={data.get('file_uuid', '')}")
return file_info
else:
logger.error(f"[QQ] Rich media upload failed: status={resp.status_code}, "
f"body={resp.text}")
return ""
except Exception as e:
logger.error(f"[QQ] Rich media upload error: {e}")
return ""
def _upload_rich_media_base64(self, file_path: str, file_type: int, msg: QQMessage,
event_type: str) -> str:
"""Upload local file via base64 file_data field."""
if event_type == "GROUP_AT_MESSAGE_CREATE":
group_openid = msg._rawmsg.get("group_openid", "")
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
elif event_type == "C2C_MESSAGE_CREATE":
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
or msg.from_user_id)
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
else:
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
return ""
try:
with open(file_path, "rb") as f:
file_data = base64.b64encode(f.read()).decode("utf-8")
except Exception as e:
logger.error(f"[QQ] Failed to read file for upload: {e}")
return ""
upload_body = {
"file_type": file_type,
"file_data": file_data,
"srv_send_msg": False,
}
try:
resp = requests.post(
upload_url, json=upload_body,
headers=self._get_auth_headers(), timeout=30,
)
if resp.status_code in (200, 201):
data = resp.json()
file_info = data.get("file_info", "")
logger.info(f"[QQ] Rich media uploaded (base64): file_type={file_type}, "
f"file_uuid={data.get('file_uuid', '')}")
return file_info
else:
logger.error(f"[QQ] Rich media upload (base64) failed: status={resp.status_code}, "
f"body={resp.text}")
return ""
except Exception as e:
logger.error(f"[QQ] Rich media upload (base64) error: {e}")
return ""
def _send_media_msg(self, file_info: str, msg: QQMessage, event_type: str, msg_id: str):
"""Send a message with msg_type=7 (rich media) using file_info."""
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
if not url:
return
body["msg_type"] = 7
body["media"] = {"file_info": file_info}
self._post_message(url, body, event_type)
def _send_image(self, img_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
"""Send image reply. Supports URL and local file path."""
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
self._send_text(str(img_path_or_url), msg, event_type, msg_id)
return
if img_path_or_url.startswith("file://"):
img_path_or_url = img_path_or_url[7:]
if img_path_or_url.startswith(("http://", "https://")):
file_info = self._upload_rich_media(
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
elif os.path.exists(img_path_or_url):
file_info = self._upload_rich_media_base64(
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
else:
logger.error(f"[QQ] Image not found: {img_path_or_url}")
self._send_text("[Image send failed]", msg, event_type, msg_id)
return
if file_info:
self._send_media_msg(file_info, msg, event_type, msg_id)
else:
self._send_text("[Image upload failed]", msg, event_type, msg_id)
def _send_file(self, file_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
"""Send file reply."""
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
self._send_text(str(file_path_or_url), msg, event_type, msg_id)
return
if file_path_or_url.startswith("file://"):
file_path_or_url = file_path_or_url[7:]
if file_path_or_url.startswith(("http://", "https://")):
file_info = self._upload_rich_media(
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
elif os.path.exists(file_path_or_url):
file_info = self._upload_rich_media_base64(
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
else:
logger.error(f"[QQ] File not found: {file_path_or_url}")
self._send_text("[File send failed]", msg, event_type, msg_id)
return
if file_info:
self._send_media_msg(file_info, msg, event_type, msg_id)
else:
self._send_text("[File upload failed]", msg, event_type, msg_id)
def _send_media(self, path_or_url: str, msg: QQMessage, event_type: str,
msg_id: str, file_type: int):
"""Generic media send for video/voice etc."""
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
self._send_text(str(path_or_url), msg, event_type, msg_id)
return
if path_or_url.startswith("file://"):
path_or_url = path_or_url[7:]
if path_or_url.startswith(("http://", "https://")):
file_info = self._upload_rich_media(path_or_url, file_type, msg, event_type)
elif os.path.exists(path_or_url):
file_info = self._upload_rich_media_base64(path_or_url, file_type, msg, event_type)
else:
logger.error(f"[QQ] Media not found: {path_or_url}")
return
if file_info:
self._send_media_msg(file_info, msg, event_type, msg_id)
else:
logger.error(f"[QQ] Media upload failed: {path_or_url}")

123
channel/qq/qq_message.py Normal file
View File

@@ -0,0 +1,123 @@
import os
import requests
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.log import logger
from common.utils import expand_path
from config import conf
def _get_tmp_dir() -> str:
"""Return the workspace tmp directory (absolute path), creating it if needed."""
ws_root = expand_path(conf().get("agent_workspace", "~/cow"))
tmp_dir = os.path.join(ws_root, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
return tmp_dir
class QQMessage(ChatMessage):
"""Message wrapper for QQ Bot (websocket long-connection mode)."""
def __init__(self, event_data: dict, event_type: str):
super().__init__(event_data)
self.msg_id = event_data.get("id", "")
self.create_time = event_data.get("timestamp", "")
self.is_group = event_type in ("GROUP_AT_MESSAGE_CREATE",)
self.event_type = event_type
author = event_data.get("author", {})
from_user_id = author.get("member_openid", "") or author.get("id", "")
group_openid = event_data.get("group_openid", "")
content = event_data.get("content", "").strip()
attachments = event_data.get("attachments", [])
has_image = any(
a.get("content_type", "").startswith("image/") for a in attachments
) if attachments else False
if has_image and not content:
self.ctype = ContextType.IMAGE
img_attachment = next(
a for a in attachments if a.get("content_type", "").startswith("image/")
)
img_url = img_attachment.get("url", "")
if img_url and not img_url.startswith("http"):
img_url = "https://" + img_url
tmp_dir = _get_tmp_dir()
image_path = os.path.join(tmp_dir, f"qq_{self.msg_id}.png")
try:
resp = requests.get(img_url, timeout=30)
resp.raise_for_status()
with open(image_path, "wb") as f:
f.write(resp.content)
self.content = image_path
self.image_path = image_path
logger.info(f"[QQ] Image downloaded: {image_path}")
except Exception as e:
logger.error(f"[QQ] Failed to download image: {e}")
self.content = "[Image download failed]"
self.image_path = None
elif has_image and content:
self.ctype = ContextType.TEXT
image_paths = []
tmp_dir = _get_tmp_dir()
for idx, att in enumerate(attachments):
if not att.get("content_type", "").startswith("image/"):
continue
img_url = att.get("url", "")
if img_url and not img_url.startswith("http"):
img_url = "https://" + img_url
img_path = os.path.join(tmp_dir, f"qq_{self.msg_id}_{idx}.png")
try:
resp = requests.get(img_url, timeout=30)
resp.raise_for_status()
with open(img_path, "wb") as f:
f.write(resp.content)
image_paths.append(img_path)
except Exception as e:
logger.error(f"[QQ] Failed to download mixed image: {e}")
content_parts = [content]
for p in image_paths:
content_parts.append(f"[图片: {p}]")
self.content = "\n".join(content_parts)
else:
self.ctype = ContextType.TEXT
self.content = content
if event_type == "GROUP_AT_MESSAGE_CREATE":
self.from_user_id = from_user_id
self.to_user_id = ""
self.other_user_id = group_openid
self.actual_user_id = from_user_id
self.actual_user_nickname = from_user_id
elif event_type == "C2C_MESSAGE_CREATE":
user_openid = author.get("user_openid", "") or from_user_id
self.from_user_id = user_openid
self.to_user_id = ""
self.other_user_id = user_openid
self.actual_user_id = user_openid
elif event_type == "AT_MESSAGE_CREATE":
self.from_user_id = from_user_id
self.to_user_id = ""
channel_id = event_data.get("channel_id", "")
self.other_user_id = channel_id
self.actual_user_id = from_user_id
self.actual_user_nickname = author.get("username", from_user_id)
elif event_type == "DIRECT_MESSAGE_CREATE":
self.from_user_id = from_user_id
self.to_user_id = ""
guild_id = event_data.get("guild_id", "")
self.other_user_id = f"dm_{guild_id}_{from_user_id}"
self.actual_user_id = from_user_id
self.actual_user_nickname = author.get("username", from_user_id)
else:
raise NotImplementedError(f"Unsupported QQ event type: {event_type}")
logger.debug(f"[QQ] Message parsed: type={event_type}, ctype={self.ctype}, "
f"from={self.from_user_id}, content_len={len(self.content)}")

View File

@@ -78,6 +78,7 @@ class TerminalChannel(ChatChannel):
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
context["isgroup"] = False
if context:
self.produce(context)
else:

Some files were not shown because too many files have changed in this diff Show More