Browse Source

fix: ensure get access_token thread-safe

master
lanvent 1 year ago
parent
commit
c6601aaeed
4 changed files with 29 additions and 21 deletions
  1. +1
    -1
      channel/wechatcom/README.md
  2. +3
    -3
      channel/wechatcom/wechatcomapp_channel.py
  3. +21
    -0
      channel/wechatcom/wechatcomapp_client.py
  4. +4
    -17
      channel/wechatcom/wechatcomapp_message.py

+ 1
- 1
channel/wechatcom/README.md View File

@@ -54,4 +54,4 @@


AIGC开放社区中已经部署了多个可免费使用的Bot,扫描下方的二维码会自动邀请你来体验。 AIGC开放社区中已经部署了多个可免费使用的Bot,扫描下方的二维码会自动邀请你来体验。


<img width="360" src="./docs/images/aigcopen.png">
<img width="200" src="../../docs/images/aigcopen.png">

+ 3
- 3
channel/wechatcom/wechatcomapp_channel.py View File

@@ -1,4 +1,3 @@
#!/usr/bin/env python
# -*- coding=utf-8 -*- # -*- coding=utf-8 -*-
import io import io
import os import os
@@ -6,7 +5,7 @@ import textwrap


import requests import requests
import web import web
from wechatpy.enterprise import WeChatClient, create_reply, parse_message
from wechatpy.enterprise import create_reply, parse_message
from wechatpy.enterprise.crypto import WeChatCrypto from wechatpy.enterprise.crypto import WeChatCrypto
from wechatpy.enterprise.exceptions import InvalidCorpIdException from wechatpy.enterprise.exceptions import InvalidCorpIdException
from wechatpy.exceptions import InvalidSignatureException, WeChatClientException from wechatpy.exceptions import InvalidSignatureException, WeChatClientException
@@ -14,6 +13,7 @@ from wechatpy.exceptions import InvalidSignatureException, WeChatClientException
from bridge.context import Context from bridge.context import Context
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel from channel.chat_channel import ChatChannel
from channel.wechatcom.wechatcomapp_client import WechatComAppClient
from channel.wechatcom.wechatcomapp_message import WechatComAppMessage from channel.wechatcom.wechatcomapp_message import WechatComAppMessage
from common.log import logger from common.log import logger
from common.singleton import singleton from common.singleton import singleton
@@ -38,7 +38,7 @@ class WechatComAppChannel(ChatChannel):
"[wechatcom] init: corp_id: {}, secret: {}, agent_id: {}, token: {}, aes_key: {}".format(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key) "[wechatcom] init: corp_id: {}, secret: {}, agent_id: {}, token: {}, aes_key: {}".format(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
) )
self.crypto = WeChatCrypto(self.token, self.aes_key, self.corp_id) self.crypto = WeChatCrypto(self.token, self.aes_key, self.corp_id)
self.client = WeChatClient(self.corp_id, self.secret) # todo: 这里可能有线程安全问题
self.client = WechatComAppClient(self.corp_id, self.secret)


def startup(self): def startup(self):
# start message listener # start message listener


+ 21
- 0
channel/wechatcom/wechatcomapp_client.py View File

@@ -0,0 +1,21 @@
import threading
import time

from wechatpy.enterprise import WeChatClient


class WechatComAppClient(WeChatClient):
def __init__(self, corp_id, secret, access_token=None, session=None, timeout=None, auto_retry=True):
super(WechatComAppClient, self).__init__(corp_id, secret, access_token, session, timeout, auto_retry)
self.fetch_access_token_lock = threading.Lock()

def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
with self.fetch_access_token_lock:
access_token = self.session.get(self.access_token_key)
if access_token:
if not self.expires_at:
return access_token
timestamp = time.time()
if self.expires_at - timestamp > 60:
return access_token
return super().fetch_access_token()

+ 4
- 17
channel/wechatcom/wechatcomapp_message.py View File

@@ -1,14 +1,9 @@
import re

import requests
from wechatpy.enterprise import WeChatClient from wechatpy.enterprise import WeChatClient


from bridge.context import ContextType from bridge.context import ContextType
from channel.chat_message import ChatMessage from channel.chat_message import ChatMessage
from common.log import logger from common.log import logger
from common.tmp_dir import TmpDir from common.tmp_dir import TmpDir
from lib import itchat
from lib.itchat.content import *




class WechatComAppMessage(ChatMessage): class WechatComAppMessage(ChatMessage):
@@ -23,9 +18,7 @@ class WechatComAppMessage(ChatMessage):
self.content = msg.content self.content = msg.content
elif msg.type == "voice": elif msg.type == "voice":
self.ctype = ContextType.VOICE self.ctype = ContextType.VOICE
self.content = (
TmpDir().path() + msg.media_id + "." + msg.format
) # content直接存临时目录路径
self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径


def download_voice(): def download_voice():
# 如果响应状态码是200,则将响应内容写入本地文件 # 如果响应状态码是200,则将响应内容写入本地文件
@@ -34,9 +27,7 @@ class WechatComAppMessage(ChatMessage):
with open(self.content, "wb") as f: with open(self.content, "wb") as f:
f.write(response.content) f.write(response.content)
else: else:
logger.info(
f"[wechatcom] Failed to download voice file, {response.content}"
)
logger.info(f"[wechatcom] Failed to download voice file, {response.content}")


self._prepare_fn = download_voice self._prepare_fn = download_voice
elif msg.type == "image": elif msg.type == "image":
@@ -50,15 +41,11 @@ class WechatComAppMessage(ChatMessage):
with open(self.content, "wb") as f: with open(self.content, "wb") as f:
f.write(response.content) f.write(response.content)
else: else:
logger.info(
f"[wechatcom] Failed to download image file, {response.content}"
)
logger.info(f"[wechatcom] Failed to download image file, {response.content}")


self._prepare_fn = download_image self._prepare_fn = download_image
else: else:
raise NotImplementedError(
"Unsupported message type: Type:{} ".format(msg.type)
)
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))


self.from_user_id = msg.source self.from_user_id = msg.source
self.to_user_id = msg.target self.to_user_id = msg.target


Loading…
Cancel
Save