py3_cookbook_notes_02
文章目录
最近在看Python Cookbook第三版,将看书过程中一些平时不太容易注意的知识点记录下来。
函数
可接受任意数量参数的函数
def avg(first, *rest):
    """Return the arithmetic mean of one or more numbers.

    ``first`` is a required positional argument; any further positional
    arguments are gathered into the ``rest`` tuple.
    """
    total = first + sum(rest)
    count = 1 + len(rest)
    return total / count


# Sample use
avg(1, 2)  # 1.5
avg(1, 2, 3, 4)  # 2.5
import html


def make_element(name, value, **attrs):
    """Return the markup string ``<name key="val" ...>value</name>``.

    :param name: tag name, inserted verbatim.
    :param value: element text. It is converted with ``str()`` and HTML-escaped.
    :param attrs: attributes, emitted in keyword order. Values are converted
        with ``str()`` and escaped with ``quote=True``, so a value containing
        ``"``, ``<`` or ``&`` cannot break out of the attribute. Keys are
        inserted verbatim.
    :return: the element as a string.
    """
    attr_str = ''.join(
        ' {}="{}"'.format(key, html.escape(str(val), quote=True))
        for key, val in attrs.items()
    )
    return '<{name}{attrs}>{value}</{name}>'.format(
        name=name,
        attrs=attr_str,
        value=html.escape(str(value)))


# Example
# Creates '<item size="large" quantity="6">Albatross</item>'
make_element('item', 'Albatross', size='large', quantity=6)
# Creates '<p>&lt;spam&gt;</p>'
make_element('p', '<spam>')
只接受关键字参数的函数
def recv(maxsize, *, block):
    """Receive a message (placeholder body).

    ``block`` comes after the bare ``*``, so it can only be passed by keyword.
    """
    return None
# `block` follows the bare `*` in recv's signature, so it is keyword-only:
recv(1024, True)  # TypeError: `block` cannot be passed positionally
recv(1024, block=True)  # Ok
def minimum(*values, clip=None):
    """Return the smallest of *values*, optionally raised to a lower bound.

    :param values: one or more comparable values. ``min()`` raises
        ValueError when none are given.
    :param clip: keyword-only lower bound. When given, the result is never
        smaller than ``clip``.
    """
    m = min(values)
    if clip is not None:
        m = max(m, clip)
    return m


# The function was originally misspelled; keep that name working too.
mininum = minimum

minimum(1, 5, 2, -5, 10)  # Returns -5
minimum(1, 5, 2, -5, 10, clip=0)  # Returns 0
给函数参数增加元信息
def add(x: int, y: int) -> int:
    """Return ``x + y``.

    The annotations are only metadata stored in ``add.__annotations__``.
    Nothing enforces them at runtime.
    """
    result = x + y
    return result
定义匿名或内联函数
>>> add = lambda x, y: x + y
>>> add(2,3)
5
>>> add('hello', 'world')
'helloworld'
>>> names = ['David Beazley', 'Brian Jones',
... 'Raymond Hettinger', 'Ned Batchelder']
>>> sorted(names, key=lambda name: name.split()[-1].lower())
['Ned Batchelder', 'David Beazley', 'Raymond Hettinger', 'Brian Jones']
匿名函数捕获变量值
>>> x = 10
>>> a = lambda y, x=x: x + y
>>> x = 20
>>> b = lambda y, x=x: x + y
>>> a(10)
20
>>> b(10)
30
减少可调用对象的参数个数
def spam(a, b, c, d):
    """Print the four arguments separated by spaces (demo target for partial)."""
    print(*(a, b, c, d))
>>> from functools import partial
>>> s1 = partial(spam, 1) # a = 1
>>> s1(2, 3, 4)
1 2 3 4
>>> s2 = partial(spam, d=42) # d = 42
>>> s2(1, 2, 3)
1 2 3 42
>>> s2(4, 5, 5)
4 5 5 42
>>> s3 = partial(spam, 1, 2, d=42) # a = 1, b = 2, d = 42
>>> s3(3)
1 2 3 42
将单方法的类转换为函数
from urllib.request import urlopen


class UrlTemplate:
    """Hold a URL template with ``{placeholders}`` and open filled-in URLs.

    This is a class with a single real method.
    """

    def __init__(self, template):
        self.template = template

    def open(self, **kwargs):
        """Fill the template with *kwargs* and return ``urlopen(url)``."""
        url = self.template.format_map(kwargs)
        return urlopen(url)
# Example use. Download stock data from yahoo
# NOTE(review): this Yahoo Finance CSV endpoint may no longer be served - confirm before running.
yahoo = UrlTemplate('http://finance.yahoo.com/d/quotes.csv?s={names}&f={fields}')
# Each iteration yields a line of the response body as bytes, hence the decode().
for line in yahoo.open(names='IBM,AAPL,FB', fields='sl1c1v'):
    print(line.decode('utf-8'))
转换写法:
from urllib.request import urlopen


def urltemplate(template):
    """Return a function that fills *template* with keyword args and opens the URL.

    The closure captures *template*, which replaces the single-method class.
    """
    def opener(**kwargs):
        url = template.format_map(kwargs)
        return urlopen(url)

    return opener
# Example use
# NOTE(review): this Yahoo Finance CSV endpoint may no longer be served - confirm before running.
yahoo = urltemplate('http://finance.yahoo.com/d/quotes.csv?s={names}&f={fields}')
# Calling the closure opens the URL; each iterated line is bytes.
for line in yahoo(names='IBM,AAPL,FB', fields='sl1c1v'):
    print(line.decode('utf-8'))
带额外状态信息的回调函数
3种解决方法:
- 创建一个类的实例
>>> def add(x, y):
... return x + y
...
class ResultHandler:
    """Callback holder that numbers every result it receives.

    The counter is stored on the instance, so it persists between callbacks.
    """

    def __init__(self):
        self.sequence = 0

    def handler(self, result):
        self.sequence = self.sequence + 1
        message = '[{}] Got: {}'.format(self.sequence, result)
        print(message)
>>> r = ResultHandler()
>>> apply_async(add, (2, 3), callback=r.handler)
[1] Got: 5
- 用一个闭包捕获状态值
>>> def add(x, y):
... return x + y
...
def make_handler():
    """Return a callback that numbers every result it is called with.

    The counter lives in this enclosing scope and is updated via ``nonlocal``.
    """
    sequence = 0

    def handler(result):
        nonlocal sequence
        sequence = sequence + 1
        message = '[{}] Got: {}'.format(sequence, result)
        print(message)

    return handler
>>> handler = make_handler()
>>> apply_async(add, (2, 3), callback=handler)
内联回调函数
from functools import wraps
from queue import Queue


def apply_async(func, args, *, callback):
    """Call ``func(*args)`` and pass the result to ``callback``.

    This stand-in runs synchronously. It plays the role of an asynchronous
    executor that would normally invoke ``callback`` later.
    """
    # Compute the result
    result = func(*args)
    # Invoke the callback with the result
    callback(result)


class Async:
    """A pending call request: run ``func`` with the argument tuple ``args``."""

    def __init__(self, func, args):
        self.func = func
        self.args = args


def inlined_async(func):
    """Decorator that lets a generator write callback-based code sequentially.

    The decorated generator yields ``Async`` requests. Each request is run
    with ``apply_async``, and its result is sent back into the generator as
    the value of the ``yield`` expression. The wrapper returns None once the
    generator finishes.
    """
    @wraps(func)
    def wrapper(*args):
        f = func(*args)  # func is a generator function, so f is a generator
        result_queue = Queue()
        # A just-started generator must be sent None first.
        result_queue.put(None)
        while True:
            result = result_queue.get()
            try:
                a = f.send(result)  # a is the next Async request
            except StopIteration:
                break
            # The callback queues the result for the next send().
            apply_async(a.func, a.args, callback=result_queue.put)
    return wrapper


def add(x, y):
    return x + y


@inlined_async
def test():
    r = yield Async(add, (2, 3))
    print(r)
    r = yield Async(add, ('hello', 'world'))
    print(r)
    for n in range(10):
        r = yield Async(add, (n, n))
        print(r)
    print('Goodbye')


test()
这里特别解释一下操作generator的两个方法 `send` 和 `next`:
https://stackoverflow.com/questions/12637768/python-3-send-method-of-generators
When you use `send` and the expression `yield` in a generator, you're treating it as a coroutine: a separate thread of execution that can run sequentially interleaved with its caller, but not in parallel with it. When the caller executes `R = m.send(a)`, it puts the object `a` into the generator's input slot, transfers control to the generator, and waits for a response. The generator receives object `a` as the result of `X = yield i`, and runs until it hits another `yield` expression, e.g. `Y = yield j`. Then it puts `j` into its output slot, transfers control back to the caller, and waits until it is resumed again. The caller receives `j` as the result of `R = m.send(a)`, and runs until it hits another `S = m.send(b)` statement, and so on.
`R = next(m)` is just the same as `R = m.send(None)`: it puts `None` into the generator's input slot, so if the generator checks the result of `X = yield i`, then `X` will be `None`.
访问闭包中定义的变量
def sample():
    """Return a closure that prints ``n``, with accessor functions attached.

    ``get_n`` and ``set_n`` are stored as attributes of the returned function.
    They give outside code read and write access to the closure variable ``n``.
    """
    n = 0

    def func():
        # Closure function: reads the current n from the enclosing scope.
        print('n=', n)

    def get_n():
        return n

    def set_n(value):
        nonlocal n
        n = value

    # Attach as function attributes (named after the accessors themselves)
    for accessor in (get_n, set_n):
        setattr(func, accessor.__name__, accessor)
    return func
类与对象
改变对象的字符串显示
class Pair:
    """Two-value holder with separate repr (developer) and str (display) forms."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __repr__(self):
        # !r applies repr() to each field, so the text reads like constructor code.
        template = 'Pair({0.x!r}, {0.y!r})'
        return template.format(self)

    def __str__(self):
        template = '({0.x!s}, {0.y!s})'
        return template.format(self)
自定义字符串的格式化
# Format code -> template. Each template reads attributes of the date as `d`.
_formats = {
    'ymd': '{d.year}-{d.month}-{d.day}',
    'mdy': '{d.month}/{d.day}/{d.year}',
    'dmy': '{d.day}/{d.month}/{d.year}',
}


class Date:
    """Calendar date that understands the ``ymd``/``mdy``/``dmy`` format codes.

    An empty format spec (``format(d)`` or ``'{}'.format(d)``) falls back to
    ``ymd``. An unknown code raises KeyError.
    """

    def __init__(self, year, month, day):
        self.year, self.month, self.day = year, month, day

    def __format__(self, code):
        template = _formats['ymd' if code == '' else code]
        return template.format(d=self)
>>> d = Date(2012, 12, 21)
>>> format(d)
'2012-12-21'
>>> format(d, 'mdy')
'12/21/2012'
>>> 'The date is {:ymd}'.format(d)
'The date is 2012-12-21'
>>> 'The date is {:mdy}'.format(d)
'The date is 12/21/2012'
让对象支持上下文管理协议
from socket import socket, AF_INET, SOCK_STREAM


class LazyConnection:
    """Socket connection that is only opened inside a ``with`` block.

    ``__enter__`` creates and connects the socket and returns it. ``__exit__``
    closes it, so the same object can be reused for consecutive ``with``
    blocks. Entering again while connected raises RuntimeError.
    """

    def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
        self.address = address
        self.family = family
        self.type = type
        self.sock = None

    def __enter__(self):
        if self.sock is None:
            self.sock = socket(self.family, self.type)
            self.sock.connect(self.address)
            return self.sock
        raise RuntimeError('Already connected')

    def __exit__(self, exc_ty, exc_val, tb):
        # Returning None (falsy) lets any exception from the block propagate.
        self.sock.close()
        self.sock = None
from functools import partial
conn = LazyConnection(('www.python.org', 80))
# Connection closed
with conn as s:
    # conn.__enter__() executes: connection open
    s.send(b'GET /index.html HTTP/1.0\r\n')
    s.send(b'Host: www.python.org\r\n')
    s.send(b'\r\n')  # blank line ends the request headers
    # iter(callable, sentinel) calls s.recv(8192) repeatedly until it returns
    # b'' (the peer closed the connection); the chunks are joined together.
    resp = b''.join(iter(partial(s.recv, 8192), b''))
# conn.__exit__() executes: connection closed
在类中封装属性名
class A:
    """Demonstrates Python's naming conventions for attribute visibility."""

    def __init__(self):
        self._internal = 0  # single underscore: internal by convention only
        self.public = 1  # A public attribute
        # Double underscore triggers name mangling (stored as _A__private_field),
        # so a subclass cannot accidentally override it.
        self.__private_field = 0

    def public_method(self):
        """A public method."""

    def _internal_method(self):
        """Internal helper (single leading underscore)."""

    def __private_method(self):
        """Name-mangled to ``_A__private_method``."""
创建可管理的属性
class Person:
    """Person whose ``first_name`` is validated through a property."""

    def __init__(self, first_name):
        # This assignment goes through the setter, so validation runs here too.
        self.first_name = first_name

    @property
    def first_name(self):
        """The first name; always a ``str``."""
        return self._first_name

    @first_name.setter
    def first_name(self, value):
        if isinstance(value, str):
            self._first_name = value
            return
        raise TypeError('Expected a string')

    @first_name.deleter
    def first_name(self):
        raise AttributeError("Can't delete attribute")
调用父类方法
class A:
    """Parent class providing the base ``spam`` implementation."""

    def spam(self):
        print('A.spam')


class B(A):
    """Overrides ``spam`` and then delegates to the parent implementation."""

    def spam(self):
        print('B.spam')
        super().spam()  # Call parent spam()
简化数据结构的初始化
class Structure:
    """Base class that sets attributes from positional args named in ``_fields``.

    Keyword arguments not listed in ``_fields`` become extra attributes. A
    keyword that repeats a field name raises TypeError.
    """
    # Class variable that specifies expected fields
    _fields = []

    def __init__(self, *args, **kwargs):
        expected = len(self._fields)
        if len(args) != expected:
            raise TypeError('Expected {} arguments'.format(expected))
        # Set the arguments
        for field, arg in zip(self._fields, args):
            setattr(self, field, arg)
        # Set the additional arguments (if any); the key set is computed up front,
        # so popping from kwargs while looping is safe.
        for key in kwargs.keys() - self._fields:
            setattr(self, key, kwargs.pop(key))
        # Anything left in kwargs duplicates a positional field.
        if kwargs:
            raise TypeError('Duplicate values for {}'.format(','.join(kwargs)))


class Stock(Structure):
    _fields = ['name', 'shares', 'price']


s1 = Stock('ACME', 50, 91.1)
s2 = Stock('ACME', 50, 91.1, date='8/2/2012')
定义接口或者抽象基类
from abc import ABCMeta, abstractmethod
import io


class IStream(metaclass=ABCMeta):
    """Abstract stream interface: concrete subclasses must define read and write."""

    @abstractmethod
    def read(self, maxbytes=-1):
        """Read up to *maxbytes*."""

    @abstractmethod
    def write(self, data):
        """Write *data*."""


class SocketStream(IStream):
    """Concrete (stub) implementation of IStream."""

    def read(self, maxbytes=-1):
        pass

    def write(self, data):
        pass


# Register the built-in I/O classes as supporting our interface
IStream.register(io.IOBase)
import collections
# The ABCs live in collections.abc; the collections.Sequence-style aliases
# were removed in Python 3.10.
import collections.abc

x = [1, 2, 3]  # example object to classify; substitute any value

# Check if x is a sequence
if isinstance(x, collections.abc.Sequence):
    ...
# Check if x is iterable
if isinstance(x, collections.abc.Iterable):
    ...
# Check if x has a size
if isinstance(x, collections.abc.Sized):
    ...
# Check if x is a mapping
if isinstance(x, collections.abc.Mapping):
    ...
实现数据模型的类型约束
# Base class. Uses a descriptor to set a value
class Descriptor:
    """Data descriptor that stores its value in the owning instance's __dict__.

    Keyword options passed to the constructor become attributes of the
    descriptor (MaxSized uses ``size``). Subclasses validate in ``__set__``
    and then delegate with ``super().__set__``. This lets checks combine
    through multiple inheritance (see UnsignedInteger and SizedString).
    """

    def __init__(self, name=None, **opts):
        self.name = name
        for option, setting in opts.items():
            setattr(self, option, setting)

    def __set__(self, instance, value):
        instance.__dict__[self.name] = value


# Descriptor for enforcing types
class Typed(Descriptor):
    expected_type = type(None)

    def __set__(self, instance, value):
        if isinstance(value, self.expected_type):
            super().__set__(instance, value)
        else:
            raise TypeError('expected ' + str(self.expected_type))


# Descriptor for enforcing values
class Unsigned(Descriptor):
    def __set__(self, instance, value):
        if value < 0:
            raise ValueError('Expected >= 0')
        super().__set__(instance, value)


class MaxSized(Descriptor):
    def __init__(self, name=None, **opts):
        # A size limit is mandatory for this descriptor.
        if 'size' not in opts:
            raise TypeError('missing size option')
        super().__init__(name, **opts)

    def __set__(self, instance, value):
        limit = self.size
        if len(value) >= limit:
            raise ValueError('size must be < ' + str(limit))
        super().__set__(instance, value)


class Integer(Typed):
    expected_type = int


class UnsignedInteger(Integer, Unsigned):
    pass


class Float(Typed):
    expected_type = float


class UnsignedFloat(Float, Unsigned):
    pass


class String(Typed):
    expected_type = str


class SizedString(String, MaxSized):
    pass


class Stock:
    """Record whose attributes are validated by the descriptors above."""
    # Specify constraints
    name = SizedString('name', size=8)
    shares = UnsignedInteger('shares')
    price = UnsignedFloat('price')

    def __init__(self, name, shares, price):
        self.name = name
        self.shares = shares
        self.price = price
实现自定义容器
# `import collections.abc` also binds the name `collections`.
import collections.abc


class A(collections.abc.Iterable):
    """Minimal concrete subclass of the Iterable ABC.

    Subclassing an ABC forces the subclass to implement its abstract methods
    (here ``__iter__``); otherwise instantiation fails with TypeError.
    (``collections.Iterable`` no longer exists in Python 3.10+, so the ABC
    comes from ``collections.abc``.)
    """

    def __iter__(self):
        # Must take `self` and return an iterator; this example container is empty.
        return iter(())
属性的代理访问
class A:
    """Class whose methods are exposed through the B2 proxy."""

    def spam(self, x):
        pass

    def foo(self):
        pass


class B2:
    """Proxy built on ``__getattr__``; handy when there are many methods to forward."""

    def __init__(self):
        self._a = A()

    def bar(self):
        pass

    # Expose all of the methods defined on class A
    def __getattr__(self, name):
        """Forward attribute lookups that fail normally to the wrapped ``A``.

        ``__getattr__`` is only a fallback: Python calls it only when an
        attribute is not found. So ``bar`` and ``_a`` are never forwarded.
        """
        return getattr(self._a, name)
在类中定义多个构造器
import time


class Date:
    """Date with an alternate constructor implemented as a classmethod (approach 1)."""

    # Primary constructor
    def __init__(self, year, month, day):
        self.year = year
        self.month = month
        self.day = day

    # Alternate constructor
    @classmethod
    def today(cls):
        """Build an instance for the current local date.

        Using ``cls`` instead of ``Date`` means subclasses get their own type.
        """
        now = time.localtime()
        return cls(now.tm_year, now.tm_mon, now.tm_mday)
创建不调用init方法的实例
class Date:
    """Plain date record; its __init__ is deliberately bypassed below."""

    def __init__(self, year, month, day):
        self.year, self.month, self.day = year, month, day


# Create an instance without running __init__: no attributes exist yet.
d = Date.__new__(Date)
>>> data = {'year':2012, 'month':8, 'day':29}
>>> for key, value in data.items():
... setattr(d, key, value)
...
有用的方法扩展其他类的功能
def LoggedMapping(cls):
    """Class decorator (approach 2) that logs item get/set/delete on *cls*.

    The class's original ``__getitem__``, ``__setitem__`` and ``__delitem__``
    are captured first. They are then replaced by versions that print a
    message and delegate to the originals.
    """
    original_get = cls.__getitem__
    original_set = cls.__setitem__
    original_del = cls.__delitem__

    def __getitem__(self, key):
        print('Getting ' + str(key))
        return original_get(self, key)

    def __setitem__(self, key, value):
        print('Setting {} = {!r}'.format(key, value))
        return original_set(self, key, value)

    def __delitem__(self, key):
        print('Deleting ' + str(key))
        return original_del(self, key)

    replacements = (('__getitem__', __getitem__),
                    ('__setitem__', __setitem__),
                    ('__delitem__', __delitem__))
    for attr, func in replacements:
        setattr(cls, attr, func)
    return cls


@LoggedMapping
class LoggedDict(dict):
    pass
实现状态对象或者状态机
class Connection:
    """Connection that delegates every operation to its current state class.

    This is the "one class per state" approach. The state classes only have
    static methods, so ``_state`` holds a class, never an instance.
    """

    def __init__(self):
        self.new_state(ClosedConnectionState)

    def new_state(self, newstate):
        self._state = newstate

    # Delegate to the state class
    def read(self):
        return self._state.read(self)

    def write(self, data):
        return self._state.write(self, data)

    def open(self):
        return self._state.open(self)

    def close(self):
        return self._state.close(self)


# Connection state base class
class ConnectionState:
    """Interface for states; every operation must be overridden."""

    @staticmethod
    def read(conn):
        raise NotImplementedError()

    @staticmethod
    def write(conn, data):
        raise NotImplementedError()

    @staticmethod
    def open(conn):
        raise NotImplementedError()

    @staticmethod
    def close(conn):
        raise NotImplementedError()


# Implementation of different states
class ClosedConnectionState(ConnectionState):
    """Closed: only open() is allowed, and it switches to the open state."""

    @staticmethod
    def read(conn):
        raise RuntimeError('Not open')

    @staticmethod
    def write(conn, data):
        raise RuntimeError('Not open')

    @staticmethod
    def open(conn):
        conn.new_state(OpenConnectionState)

    @staticmethod
    def close(conn):
        raise RuntimeError('Already closed')


class OpenConnectionState(ConnectionState):
    """Open: read/write work, and close() switches back to the closed state."""

    @staticmethod
    def read(conn):
        print('reading')

    @staticmethod
    def write(conn, data):
        print('writing')

    @staticmethod
    def open(conn):
        raise RuntimeError('Already open')

    @staticmethod
    def close(conn):
        conn.new_state(ClosedConnectionState)


c = Connection()
>>> c._state
<class '__main__.ClosedConnectionState'>
>>> c.open()
>>> c._state
<class '__main__.OpenConnectionState'>
>>> c.read()
reading
>>> c.write('hello')
writing
>>> c.close()
>>> c._state
<class '__main__.ClosedConnectionState'>
>>>
通过字符串调用对象方法
import math
import operator


class Point:
    """2-D point used to show calling a method by its name string."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return 'Point({!r:},{!r:})'.format(self.x, self.y)

    def distance(self, x, y):
        """Euclidean distance from this point to (x, y)."""
        dx = self.x - x
        dy = self.y - y
        return math.hypot(dx, dy)


p = Point(2, 3)
# getattr() looks up the bound method by name; it is then called directly.
d = getattr(p, 'distance')(0, 0)  # Calls p.distance(0, 0)
# methodcaller bundles the method name and arguments into a reusable callable.
operator.methodcaller('distance', 0, 0)(p)
实现访问者模式(递归实现法)
class Node:
    """Base class for expression-tree nodes."""
    pass


class UnaryOperator(Node):
    def __init__(self, operand):
        self.operand = operand


class BinaryOperator(Node):
    def __init__(self, left, right):
        self.left = left
        self.right = right


class Add(BinaryOperator):
    pass


class Sub(BinaryOperator):
    pass


class Mul(BinaryOperator):
    pass


class Div(BinaryOperator):
    pass


class Negate(UnaryOperator):
    pass


class Number(Node):
    def __init__(self, value):
        self.value = value


class NodeVisitor:
    """Recursive visitor: ``visit(node)`` dispatches to ``visit_<ClassName>``."""

    def visit(self, node):
        methname = 'visit_' + type(node).__name__
        meth = getattr(self, methname, None)
        if meth is None:
            meth = self.generic_visit
        return meth(node)

    def generic_visit(self, node):
        """Called when no matching visit_* method exists; always raises."""
        raise RuntimeError('No {} method'.format('visit_' + type(node).__name__))


class Evaluator(NodeVisitor):
    """Evaluate an arithmetic expression tree to a number."""

    def visit_Number(self, node):
        return node.value

    def visit_Add(self, node):
        return self.visit(node.left) + self.visit(node.right)

    def visit_Sub(self, node):
        return self.visit(node.left) - self.visit(node.right)

    def visit_Mul(self, node):
        return self.visit(node.left) * self.visit(node.right)

    def visit_Div(self, node):
        return self.visit(node.left) / self.visit(node.right)

    def visit_Negate(self, node):
        # The operand is a Node; evaluate it first. Negating the Node object
        # itself raises TypeError.
        return -self.visit(node.operand)
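>>> t1 = Sub(Number(3), Number(4))
>>> t2 = Mul(Number(2), t1)
>>> t3 = Div(t2, Number(5))
>>> t4 = Add(Number(1), t3)
>>> e = Evaluator()
>>> e.visit(t4)
0.6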
class HTTPHandler:
    """Dispatch a request to ``do_<METHOD>`` according to ``request.request_method``.

    An unsupported method raises AttributeError from the lookup.
    """

    def handle(self, request):
        handler = getattr(self, 'do_' + request.request_method)
        handler(request)

    def do_GET(self, request):
        pass

    def do_POST(self, request):
        pass

    def do_HEAD(self, request):
        pass
实现访问者模式(非递归实现法)
import types


class Node:
    """Base class for expression-tree nodes."""
    pass


class NodeVisitor:
    """Visitor that walks a node tree with an explicit stack instead of recursion.

    A ``visit_<ClassName>`` method may return a plain value. It may also be a
    generator that yields child nodes, receives each child's value back via
    ``send()``, and finally yields its own result. The work is driven by the
    ``stack`` list rather than nested Python calls, so very deep trees do not
    hit the recursion limit.
    """
    def visit(self, node):
        stack = [node]
        # Value produced by the most recently completed item; it is fed into
        # the generator that is waiting for it.
        last_result = None
        while stack:
            try:
                last = stack[-1]
                if isinstance(last, types.GeneratorType):
                    # Resume the waiting generator with the pending child value and
                    # push whatever it yields next (a child Node or its final value).
                    stack.append(last.send(last_result))
                    last_result = None
                elif isinstance(last, Node):
                    # Replace the node with the output of its visit_* method
                    # (either a generator or a plain value).
                    stack.append(self._visit(stack.pop()))
                else:
                    # A finished value: pop it so it is sent to the generator below.
                    last_result = stack.pop()
            except StopIteration:
                # The generator on top is exhausted, so drop it. last_result still
                # holds the value it yielded last, which passes to the item below.
                stack.pop()
        return last_result

    def _visit(self, node):
        """Dispatch to ``visit_<ClassName>``, falling back to generic_visit."""
        methname = 'visit_' + type(node).__name__
        meth = getattr(self, methname, None)
        if meth is None:
            meth = self.generic_visit
        return meth(node)

    def generic_visit(self, node):
        """Called when no visit_* method exists; always raises RuntimeError."""
        raise RuntimeError('No {} method'.format('visit_' + type(node).__name__))


class UnaryOperator(Node):
    def __init__(self, operand):
        self.operand = operand


class BinaryOperator(Node):
    def __init__(self, left, right):
        self.left = left
        self.right = right


class Add(BinaryOperator):
    pass


class Sub(BinaryOperator):
    pass


class Mul(BinaryOperator):
    pass


class Div(BinaryOperator):
    pass


class Negate(UnaryOperator):
    pass


class Number(Node):
    def __init__(self, value):
        self.value = value


# A sample visitor class that evaluates expressions
class Evaluator(NodeVisitor):
    """Evaluate arithmetic trees; the generator methods yield children to get their values."""
    def visit_Number(self, node):
        return node.value

    def visit_Add(self, node):
        # Each inner yield hands a child to visit() and receives its value;
        # the outer yield reports this node's result.
        yield (yield node.left) + (yield node.right)

    def visit_Sub(self, node):
        yield (yield node.left) - (yield node.right)

    def visit_Mul(self, node):
        yield (yield node.left) * (yield node.right)

    def visit_Div(self, node):
        yield (yield node.left) / (yield node.right)

    def visit_Negate(self, node):
        yield - (yield node.operand)
>>> a = Number(0)
>>> for n in range(1,100000):
... a = Add(a, Number(n))
...
>>> e = Evaluator()
>>> e.visit(a)
4999950000
循环引用数据结构的内存管理
import weakref


class Node:
    """Tree node that holds its children strongly and its parent weakly.

    A weak parent link keeps the parent/child cycle from keeping the parent
    alive. Once the parent is garbage-collected, ``parent`` returns None.
    """

    def __init__(self, value):
        self.value = value
        self._parent = None  # None or a weakref.ref to the parent node
        self.children = []

    def __repr__(self):
        return 'Node({!r:})'.format(self.value)

    # property that manages the parent as a weak-reference
    @property
    def parent(self):
        """Parent node, or None if unset or already garbage-collected."""
        return None if self._parent is None else self._parent()

    @parent.setter
    def parent(self, node):
        # weakref.ref(None) raises TypeError, so store None directly to allow
        # detaching a node from its parent.
        self._parent = None if node is None else weakref.ref(node)

    def add_child(self, child):
        """Append *child* and set this node as its (weakly referenced) parent."""
        self.children.append(child)
        child.parent = self
>>> root = Node('parent')
>>> c1 = Node('child')
>>> root.add_child(c1)
>>> print(c1.parent)
Node('parent')
>>> del root
>>> print(c1.parent)
None
>>>
让类支持比较操作
from functools import total_ordering


class Room:
    """Rectangular room; ``square_feet`` is computed once in the constructor."""

    def __init__(self, name, length, width):
        self.name = name
        self.length = length
        self.width = width
        self.square_feet = self.length * self.width


@total_ordering
class House:
    """House whose comparisons are based on total living space.

    Only ``__eq__`` and ``__lt__`` are written here. ``total_ordering`` derives
    ``__le__``, ``__gt__`` and ``__ge__`` from them.
    """

    def __init__(self, name, style):
        self.name = name
        self.style = style
        self.rooms = list()

    @property
    def living_space_footage(self):
        """Sum of ``square_feet`` over all rooms (0 for an empty house)."""
        total = 0
        for room in self.rooms:
            total += room.square_feet
        return total

    def add_room(self, room):
        self.rooms.append(room)

    def __str__(self):
        return '{}: {} square foot {}'.format(self.name,
                                              self.living_space_footage,
                                              self.style)

    def __eq__(self, other):
        return self.living_space_footage == other.living_space_footage

    def __lt__(self, other):
        return self.living_space_footage < other.living_space_footage
创建缓存实例
import weakref


class CachedSpamManager:
    """Hand out one shared Spam instance per name for as long as it is in use.

    Instances are cached in a WeakValueDictionary. An entry disappears once
    no caller holds a strong reference to that Spam any more.
    """

    def __init__(self):
        self._cache = weakref.WeakValueDictionary()

    def get_spam(self, name):
        """Return the cached Spam for *name*, creating it if absent."""
        # A single get() avoids the `name in cache` then `cache[name]` race:
        # the weakly held object could be collected between the two lookups,
        # which would turn the second one into a KeyError.
        temp = self._cache.get(name)
        if temp is None:
            temp = Spam._new(name)  # Modified creation: bypasses Spam.__init__
            self._cache[name] = temp
        return temp

    def clear(self):
        self._cache.clear()


class Spam:
    """Object that can only be created through CachedSpamManager."""

    def __init__(self, *args, **kwargs):
        raise RuntimeError("Can't instantiate directly")

    # Alternate constructor
    @classmethod
    def _new(cls, name):
        self = cls.__new__(cls)
        self.name = name
        return self
元编程
在函数上添加包装器
import time
from functools import wraps


def timethis(func):
    """Decorator that prints how long each call to *func* takes.

    After every successful call it prints ``<func name> <elapsed seconds>``.
    The wrapped function's return value is passed through unchanged. Thanks to
    ``functools.wraps``, the undecorated function stays reachable as
    ``__wrapped__``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter() is monotonic and high-resolution. time.time() follows
        # the wall clock, which can jump (e.g. NTP adjustments), so it is
        # unsuitable for measuring durations.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(func.__name__, end - start)
        return result
    return wrapper
>>> @timethis
... def countdown(n):
... '''
... Counts down
... '''
... while n > 0:
... n -= 1
...
>>> countdown(100000)
countdown 0.008917808532714844
>>> countdown.__wrapped__(100000)
带可选参数的装饰器
from functools import wraps, partial
import logging


def logged(func=None, *, level=logging.DEBUG, name=None, message=None):
    """Decorator that logs a message every time the function is called.

    It works bare (``@logged``) or with options (``@logged(level=...)``).
    When no function is given, it returns a partial of itself that waits for
    the function.

    :param level: logging level of the emitted record.
    :param name: logger name; defaults to the function's module.
    :param message: log text; defaults to the function's name.
    """
    if func is None:
        return partial(logged, level=level, name=name, message=message)

    log = logging.getLogger(name or func.__module__)
    logmsg = message or func.__name__

    @wraps(func)
    def wrapper(*args, **kwargs):
        log.log(level, logmsg)
        return func(*args, **kwargs)

    return wrapper


# Example use
@logged
def add(x, y):
    return x + y


@logged(level=logging.CRITICAL, name='example')
def spam():
    print('Spam!')
利用装饰器强制函数上的类型检查
from inspect import signature
from functools import wraps


def typeassert(*ty_args, **ty_kwargs):
    """Decorator factory that checks argument types with ``isinstance``.

    Types are matched to parameters the same way the call arguments would be
    (positionally or by keyword). Parameters without a declared type are not
    checked. When Python runs with ``-O`` (``__debug__`` is False), the
    function is returned undecorated.
    """
    def decorate(func):
        # If in optimized mode, disable type checking
        if not __debug__:
            return func
        sig = signature(func)
        # Map function argument names to supplied types
        bound_types = sig.bind_partial(*ty_args, **ty_kwargs).arguments

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound_values = sig.bind(*args, **kwargs)
            # Enforce type assertions across supplied arguments
            for name, value in bound_values.arguments.items():
                if name not in bound_types:
                    continue
                if not isinstance(value, bound_types[name]):
                    raise TypeError(
                        'Argument {} must be {}'.format(name, bound_types[name])
                    )
            return func(*args, **kwargs)

        return wrapper

    return decorate
>>> @typeassert(int, int)
... def add(x, y):
... return x + y
将装饰器定义为类
import types
from functools import wraps


class Profiled:
    """Decorator class that counts calls to the wrapped function in ``ncalls``.

    Because it implements ``__get__``, an instance used on a method behaves
    like a function: accessed through an object, it returns a bound method.
    """

    def __init__(self, func):
        wraps(func)(self)  # copies __name__, __doc__, ... and sets __wrapped__
        self.ncalls = 0

    def __call__(self, *args, **kwargs):
        self.ncalls += 1
        return self.__wrapped__(*args, **kwargs)

    def __get__(self, instance, cls):
        if instance is not None:
            return types.MethodType(self, instance)
        return self


@Profiled
def add(x, y):
    return x + y


class Spam:
    @Profiled
    def bar(self, x):
        print(self, x)
>>> add(2, 3)
5
>>> add(4, 5)
9
>>> add.ncalls
2
>>> s = Spam()
>>> s.bar(1)
<__main__.Spam object at 0x10069e9d0> 1
>>> s.bar(2)
<__main__.Spam object at 0x10069e9d0> 2
>>> s.bar(3)
<__main__.Spam object at 0x10069e9d0> 3
>>> Spam.bar.ncalls
3
装饰器为被包装函数增加参数
from functools import wraps


def optional_debug(func):
    """Decorator that adds a keyword-only ``debug`` argument to *func*.

    With ``debug=True``, ``Calling <name>`` is printed before the call. The
    ``debug`` keyword is consumed by the wrapper and never reaches *func*.
    """
    @wraps(func)
    def wrapper(*args, debug=False, **kwargs):
        if debug:
            print('Calling', func.__name__)
        return func(*args, **kwargs)

    return wrapper
>>> @optional_debug
... def spam(a,b,c):
... print(a,b,c)
...
>>> spam(1,2,3)
1 2 3
>>> spam(1,2,3, debug=True)
Calling spam
1 2 3
>>>
使用装饰器扩充类的功能
def log_getattribute(cls):
    """Class decorator that prints every attribute read on the class's instances.

    The class's existing ``__getattribute__`` is wrapped, not replaced outright,
    so normal lookup still happens after the message is printed.
    """
    original = cls.__getattribute__

    def new_getattribute(self, name):
        print('getting:', name)
        return original(self, name)

    cls.__getattribute__ = new_getattribute
    return cls


# Example use
@log_getattribute
class A:
    def __init__(self, x):
        self.x = x

    def spam(self):
        pass
>>> a = A(42)
>>> a.x
getting: x
42
>>> a.spam()
getting: spam
>>>
使用元类控制实例的创建
只允许调用类的静态方法
class NoInstances(type):
    """Metaclass whose classes cannot be instantiated (static use only)."""

    def __call__(cls, *args, **kwargs):
        # `Spam()` goes through the metaclass __call__, so refusing here blocks
        # instance creation entirely.
        raise TypeError("Can't instantiate directly")


# Example
class Spam(metaclass=NoInstances):
    @staticmethod
    def grok(x):
        print('Spam.grok')
>>> Spam.grok(42)
Spam.grok
>>> s = Spam()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "example1.py", line 7, in __call__
raise TypeError("Can't instantiate directly")
TypeError: Can't instantiate directly
>>>
单例模式
class Singleton(type):
    """Metaclass that lets each class using it create at most one instance.

    The first call constructs the instance. Later calls return that same
    object and ignore their arguments.
    """

    def __init__(cls, *args, **kwargs):
        cls.__instance = None
        super().__init__(*args, **kwargs)

    def __call__(cls, *args, **kwargs):
        # NOTE: this check-then-create is not guarded by a lock, so two threads
        # could both see None and build two instances.
        if cls.__instance is None:
            cls.__instance = super().__call__(*args, **kwargs)
        return cls.__instance


# Example
class Spam(metaclass=Singleton):
    def __init__(self):
        print('Creating Spam')
缓存模式
import weakref


class Cached(type):
    """Metaclass that returns the same instance for repeated constructor arguments.

    Each class keeps its instances in its own WeakValueDictionary, keyed by
    the positional-argument tuple (so the arguments must be hashable). An
    entry disappears once nothing else references the instance. Only
    positional arguments are accepted.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__cache = weakref.WeakValueDictionary()

    def __call__(self, *args):
        # One get() instead of `in` followed by `[]`: the weakly held instance
        # could be collected between the two lookups, turning the second into
        # a KeyError.
        obj = self.__cache.get(args)
        if obj is None:
            obj = super().__call__(*args)
            self.__cache[args] = obj
        return obj


# Example
class Spam(metaclass=Cached):
    def __init__(self, name):
        print('Creating Spam({!r})'.format(name))
        self.name = name
不过想了下,上面的一些实现还是有并发冲突问题,搜索到openstack中有一处单例模式的写法,感觉这个写法更正确。
捕获类的属性定义顺序
from collections import OrderedDict


# A set of descriptors for various types
class Typed:
    """Type-checking data descriptor; OrderedMeta fills in the attribute name."""
    _expected_type = type(None)

    def __init__(self, name=None):
        self._name = name

    def __set__(self, instance, value):
        if not isinstance(value, self._expected_type):
            raise TypeError('Expected ' + str(self._expected_type))
        instance.__dict__[self._name] = value


class Integer(Typed):
    _expected_type = int


class Float(Typed):
    _expected_type = float


class String(Typed):
    _expected_type = str


# Metaclass that uses an OrderedDict for class body
class OrderedMeta(type):
    """Record the definition order of Typed attributes as ``cls._order``."""

    def __new__(cls, clsname, bases, clsdict):
        # Typed descriptors, in the order the class body defined them.
        order = [attr for attr, val in clsdict.items() if isinstance(val, Typed)]
        for attr in order:
            clsdict[attr]._name = attr  # tell each descriptor which attribute it manages
        namespace = dict(clsdict)
        namespace['_order'] = order
        return type.__new__(cls, clsname, bases, namespace)

    @classmethod
    def __prepare__(cls, clsname, bases):
        # The class body runs inside this mapping, so items() follows definition order.
        return OrderedDict()


class Structure(metaclass=OrderedMeta):
    def as_csv(self):
        """Join the Typed attribute values, in definition order, with commas."""
        return ','.join(str(getattr(self, name)) for name in self._order)


# Example use
class Stock(Structure):
    name = String()
    shares = Integer()
    price = Float()

    def __init__(self, name, shares, price):
        self.name = name
        self.shares = shares
        self.price = price
>>> s = Stock('GOOG',100,490.1)
>>> s.name
'GOOG'
>>> s.as_csv()
'GOOG,100,490.1'
>>> t = Stock('AAPL','a lot', 610.23)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "dupmethod.py", line 34, in __init__
TypeError: Expected <class 'int'>
>>>
定义有可选参数的元类
class MyMeta(type):
    """Metaclass that accepts the optional class keywords ``debug`` and ``synchronize``.

    Usage: ``class Spam(metaclass=MyMeta, debug=True, synchronize=True)``.
    Each hook below receives the extra keywords as keyword-only arguments.
    It consumes them itself and does not forward them to ``super()``.
    """

    # Optional
    @classmethod
    def __prepare__(cls, name, bases, *, debug=False, synchronize=False):
        # Custom processing would go here.
        return super().__prepare__(name, bases)

    # Required
    def __new__(cls, name, bases, ns, *, debug=False, synchronize=False):
        # Custom processing would go here.
        return super().__new__(cls, name, bases, ns)

    # Required
    def __init__(self, name, bases, ns, *, debug=False, synchronize=False):
        # Custom processing would go here.
        super().__init__(name, bases, ns)


class Spam(metaclass=MyMeta, debug=True, synchronize=True):
    pass
args和kwargs的强制参数签名
from inspect import Signature, Parameter


def make_sig(*names):
    """Build a Signature whose parameters are *names*, all positional-or-keyword."""
    return Signature([Parameter(n, Parameter.POSITIONAL_OR_KEYWORD) for n in names])


class StructureMeta(type):
    """Give every class a ``__signature__`` built from its ``_fields``."""

    def __new__(cls, clsname, bases, clsdict):
        fields = clsdict.get('_fields', [])
        clsdict['__signature__'] = make_sig(*fields)
        return super().__new__(cls, clsname, bases, clsdict)


class Structure(metaclass=StructureMeta):
    """Bind constructor arguments against ``__signature__`` and set them as attributes."""
    _fields = []

    def __init__(self, *args, **kwargs):
        # bind() raises TypeError on missing, extra or duplicated arguments.
        bound_values = self.__signature__.bind(*args, **kwargs)
        for name, value in bound_values.arguments.items():
            setattr(self, name, value)


# Example
class Stock(Structure):
    _fields = ['name', 'shares', 'price']


class Point(Structure):
    _fields = ['x', 'y']
在类上强制使用编程规约
from inspect import signature
import logging


class MatchSignaturesMeta(type):
    """Metaclass that warns when a subclass redefines a method with a different signature.

    Public callables in the class body are compared with the same-named
    attribute inherited from the parents. Mismatches are reported with
    ``logging.warning``; nothing is raised.
    """

    def __init__(self, clsname, bases, clsdict):
        super().__init__(clsname, bases, clsdict)
        # super(self, self) starts attribute lookup after this class in its MRO,
        # so it finds the inherited definitions.
        parent = super(self, self)
        for name, value in clsdict.items():
            if name.startswith('_') or not callable(value):
                continue
            # Get the previous definition (if any) and compare the signatures
            prev_dfn = getattr(parent, name, None)
            if not prev_dfn:
                continue
            prev_sig = signature(prev_dfn)
            val_sig = signature(value)
            if prev_sig != val_sig:
                logging.warning('Signature mismatch in %s. %s != %s',
                                value.__qualname__, prev_sig, val_sig)


# Example
class Root(metaclass=MatchSignaturesMeta):
    pass


class A(Root):
    def foo(self, x, y):
        pass

    def spam(self, x, *, z):
        pass


# Class with redefined methods, but slightly different signatures
class B(A):
    def foo(self, a, b):
        pass

    def spam(self, x, z):
        pass
以编程方式定义类
import operator
import types
import sys


def named_tuple(classname, fieldnames):
    """Create a tuple subclass named *classname* with read-only named fields.

    Each name in *fieldnames* becomes a property reading the item at that
    position. Constructing with the wrong number of arguments raises TypeError.
    """
    # One property per field, reading the tuple item at that index.
    cls_dict = {}
    for index, field in enumerate(fieldnames):
        cls_dict[field] = property(operator.itemgetter(index))

    # Make a __new__ function and add to the class dict
    def __new__(cls, *args):
        if len(args) != len(fieldnames):
            raise TypeError('Expected {} arguments'.format(len(fieldnames)))
        return tuple.__new__(cls, args)

    cls_dict['__new__'] = __new__

    # Make the class; the callback fills the fresh class namespace.
    def fill_namespace(ns):
        ns.update(cls_dict)

    cls = types.new_class(classname, (tuple,), {}, fill_namespace)
    # Set the module to that of the caller (one frame up).
    cls.__module__ = sys._getframe(1).f_globals['__name__']
    return cls
types.new_class
的详细用法参考https://docs.python.org/3/library/types.html
在定义的时候初始化类的成员
在类定义时就执行初始化或设置操作是元类的一个典型应用场景。本质上讲,一个元类会在定义时被触发, 这时候你可以执行一些额外的操作。
import operator


class StructTupleMeta(type):
    """Metaclass that adds one read-only property per name in ``cls._fields``.

    It runs when each class is defined, so subclasses only need to list
    their fields.
    """

    def __init__(cls, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for position, field in enumerate(cls._fields):
            accessor = property(operator.itemgetter(position))
            setattr(cls, field, accessor)


class StructTuple(tuple, metaclass=StructTupleMeta):
    """Tuple whose items can also be read by the names in ``_fields``."""
    _fields = []

    def __new__(cls, *args):
        expected = len(cls._fields)
        if len(args) != expected:
            raise ValueError('{} arguments required'.format(expected))
        return super().__new__(cls, args)


class Stock(StructTuple):
    _fields = ['name', 'shares', 'price']


class Point(StructTuple):
    _fields = ['x', 'y']
>>> s = Stock('ACME', 50, 91.1)
>>> s
('ACME', 50, 91.1)
>>> s[0]
'ACME'
>>> s.name
'ACME'
>>> s.shares * s.price
4555.0
避免重复的属性方法
def typed_property(name, expected_type):
    """Return a property that stores its value in ``_<name>`` and type-checks writes."""
    storage_name = '_' + name

    def getter(self):
        return getattr(self, storage_name)

    def setter(self, value):
        if not isinstance(value, expected_type):
            raise TypeError('{} must be a {}'.format(name, expected_type))
        setattr(self, storage_name, value)

    return property(getter, setter)


from functools import partial

# Pre-bound factories: String('name') makes a str-only property, and so on.
String = partial(typed_property, expected_type=str)
Integer = partial(typed_property, expected_type=int)


# Example:
class Person:
    name = String('name')
    age = Integer('age')

    def __init__(self, name, age):
        self.name = name
        self.age = age
定义上下文管理器的简单方法
import time
from contextlib import contextmanager


@contextmanager
def timethis(label):
    """Context manager that prints ``<label>: <elapsed seconds>`` when the block exits.

    The timing is printed even if the block raises (``finally``), and the
    exception then propagates unchanged.
    """
    # perf_counter() is monotonic, unlike time.time(), which can jump with
    # wall-clock adjustments.
    start = time.perf_counter()
    try:
        yield
    finally:
        end = time.perf_counter()
        print('{}: {}'.format(label, end - start))
# Example use: time a busy loop of ten million decrements.
with timethis('counting'):
    n = 10000000
    while n > 0:
        n -= 1
# Changes made to the list take effect only if all the code in the `with`
# block finishes without raising an exception.
@contextmanager
def list_transaction(orig_list):
    """Yield a working copy of *orig_list* and commit it back only on success.

    If the ``with`` block raises, the exception is re-raised at the ``yield``
    and the commit line never runs, so *orig_list* is left untouched.
    """
    staged = list(orig_list)
    yield staged
    # Reached only when the with-block completed normally: commit in place.
    orig_list[:] = staged
>>> items = [1, 2, 3]
>>> with list_transaction(items) as working:
... working.append(4)
... working.append(5)
...
>>> items
[1, 2, 3, 4, 5]
>>> with list_transaction(items) as working:
... working.append(6)
... working.append(7)
... raise RuntimeError('oops')
...
Traceback (most recent call last):
File "<stdin>", line 4, in <module>
RuntimeError: oops
>>> items
[1, 2, 3, 4, 5]
>>>
在局部变量域中执行代码
>>> def test4():
... a = 13
... loc = { 'a' : a }
... glb = { }
... exec('b = a + 1', glb, loc)
... b = loc['b']
... print(b)
...
>>> test4()
14
>>>
模块与包
控制模块被全部导入的内容
# somemodule.py
def spam():
    """Exported by ``from somemodule import *``."""


def grok():
    """Exported by ``from somemodule import *``."""


blah = 42  # not listed in __all__, so `import *` skips it

# Only export 'spam' and 'grok'
__all__ = ['spam', 'grok']
使用相对路径名导入包中子模块
mypackage/
__init__.py
A/
__init__.py
spam.py
grok.py
B/
__init__.py
bar.py
如果模块mypackage.A.spam要导入同目录下的模块grok,它应该包括的import语句如下:
# mypackage/A/spam.py
from . import grok
如果模块mypackage.A.spam要导入不同目录下的模块B.bar,它应该使用的import语句如下:
# mypackage/A/spam.py
from ..B import bar
将模块分割成多个文件
mymodule/
__init__.py
a.py
b.py
在a.py文件中插入以下代码:
# a.py
class A:
    """Base class of the split-module example; B in b.py extends it."""

    def spam(self):
        print('A.spam')
在b.py文件中插入以下代码:
# b.py
# Relative import: `.a` is the sibling module a.py in the same package.
from .a import A
class B(A):
    """Subclass of A (from a.py) that adds a bar() method."""
    def bar(self):
        print('B.bar')
最后,在 __init__.py
中,将2个文件粘合在一起:
# __init__.py
from .a import A
from .b import B
如果按照这些步骤,所产生的包mymodule将作为一个单一的逻辑模块:
>>> import mymodule
>>> a = mymodule.A()
>>> a.spam()
A.spam
>>> b = mymodule.B()
>>> b.bar()
B.bar
>>>
利用命名空间导入目录分散的代码
foo-package/
spam/
blah.py
bar-package/
spam/
grok.py
在这2个目录里,都有着共同的命名空间spam。在任何一个目录里都没有__init__.py
文件。
如果将foo-package和bar-package都加到python模块路径:
>>> import sys
>>> sys.path.extend(['foo-package', 'bar-package'])
>>> import spam.blah
>>> import spam.grok
>>>
两个不同的包目录被合并到一起,你可以导入spam.blah和spam.grok,并且它们能够工作。
在这里工作的机制被称为“包命名空间”的一个特征。从本质上讲,包命名空间是一种特殊的封装设计,为合并不同的目录的代码到一个共同的命名空间。对于大的框架,这可能是有用的,因为它允许一个框架的部分被单独地安装下载。它也使人们能够轻松地为这样的框架编写第三方附加组件和其他扩展。
包命名空间的关键是确保顶级目录中没有__init__.py
文件来作为共同的命名空间。缺失__init__.py
文件使得在导入包的时候会发生有趣的事情:这并没有产生错误,解释器创建了一个由所有包含匹配包名的目录组成的列表。特殊的包命名空间模块被创建,只读的目录列表副本被存储在其__path__
变量中。
>>> import spam
>>> spam.__path__
_NamespacePath(['foo-package/spam', 'bar-package/spam'])
>>> spam.__file__
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'module' object has no attribute '__file__'
>>> spam
<module 'spam' (namespace)>
>>>
运行目录或压缩文件
myapplication/
spam.py
bar.py
grok.py
__main__.py
如果__main__.py
存在,你可以简单地在顶级目录运行Python解释器:
bash % python3 myapplication
如果你将你的代码打包成zip文件,这种技术同样也适用,举个例子:
bash % ls
spam.py bar.py grok.py __main__.py
bash % zip -r myapp.zip *.py
bash % python3 myapp.zip
... output from __main__.py ...
读取位于包中的数据文件
mypackage/
__init__.py
somedata.dat
spam.py
现在假设spam.py文件需要读取somedata.dat文件中的内容。你可以用以下代码来完成:
# spam.py
import pkgutil
# Load 'somedata.dat' relative to this module's package; the result is the
# file's raw contents as a bytes object.
data = pkgutil.get_data(package=__package__, resource='somedata.dat')
由此产生的变量是包含该文件的原始内容的字节字符串。
将文件夹加入到sys.path
import sys
from os.path import abspath, join, dirname
# Prepend "<this file's directory>/libs" to the module search path so that
# packages placed there are found before anything else on sys.path.
sys.path[0:0] = [join(abspath(dirname(__file__)), 'libs')]
通过字符串名导入模块
>>> import importlib
>>> math = importlib.import_module('math')
>>> math.sin(2)
0.9092974268256817
>>> mod = importlib.import_module('urllib.request')
>>> u = mod.urlopen('http://www.python.org')
>>>
import importlib
# Relative import by name, resolved against the current package;
# equivalent to the statement `from . import b`.
b = importlib.import_module(name='.b', package=__package__)
安装私有的包
python3 setup.py install --user
或者
pip install --user packagename
创建新的Python环境
python3.6 -m venv ./Spam
source ./Spam/bin/activate #进入新的python环境(需用source在当前shell中执行;用bash运行只会在子shell中生效)
deactivate #退出新的python环境
分发包
projectname/
README.txt
Doc/
documentation.txt
projectname/
__init__.py
foo.py
bar.py
utils/
__init__.py
spam.py
grok.py
examples/
helloworld.py
...
接下来编写一个 setup.py 文件,类似下面这样:
# setup.py
# NOTE: distutils was deprecated by PEP 632 and removed in Python 3.12; on
# newer interpreters import `setup` from setuptools instead (same arguments).
from distutils.core import setup

# Package metadata, kept in one mapping and splatted into setup().
metadata = dict(
    name='projectname',
    version='1.0',
    author='Your Name',
    author_email='you@youraddress.com',
    url='http://www.you.com/projectname',
    packages=['projectname', 'projectname.utils'],
)
setup(**metadata)
下一步是创建一个 MANIFEST.in 文件,列出所有需要包含进分发包中的非源码文件:
# MANIFEST.in
include *.txt
recursive-include examples *
recursive-include Doc *
确保 setup.py 和 MANIFEST.in 文件放在项目的最顶级目录中。完成这些之后,就可以像下面这样执行命令来创建一个源码分发包了:
bash % python3 setup.py sdist
它会创建一个类似 “projectname-1.0.zip” 或 “projectname-1.0.tar.gz” 的文件,具体取决于你的系统平台。
网络与Web编程
作为客户端与HTTP服务交互
from urllib import request, parse
# Extra headers
headers = {
    'User-agent': 'none/ofyourbusiness',
    'Spam': 'Eggs'
}
# Base URL being accessed
url = 'http://httpbin.org/get'
# Dictionary of query parameters (if any)
parms = {
    'name1': 'value1',
    'name2': 'value2'
}
# Encode the query string
querystring = parse.urlencode(parms)
# urlopen() has no `headers` argument (passing one raises TypeError); custom
# headers are attached to a Request object, which urlopen() accepts instead
# of a URL string.
req = request.Request(url + '?' + querystring, headers=headers)
# Make a GET request and read the response; the with-block closes the
# connection once the body has been read.
with request.urlopen(req) as u:
    resp = u.read()
from urllib import request, parse
# Base URL being accessed
url = 'http://httpbin.org/post'
# Extra headers (defined here so this snippet runs on its own)
headers = {
    'User-agent': 'none/ofyourbusiness',
    'Spam': 'Eggs'
}
# Dictionary of query parameters (if any)
parms = {
    'name1': 'value1',
    'name2': 'value2'
}
# Encode the query string
querystring = parse.urlencode(parms)
# Supplying a bytes body makes the request a POST. Headers must travel on a
# Request object because urlopen() itself has no `headers` argument.
req = request.Request(url, querystring.encode('ascii'), headers=headers)
# Make a POST request and read the response; the with-block closes the
# connection afterwards.
with request.urlopen(req) as u:
    resp = u.read()
如果需要交互的服务比上面的例子更复杂,也许应该去看看 requests 库(https://pypi.python.org/pypi/requests)。
import requests

# Endpoint that receives the POST
url = 'http://httpbin.org/post'
# Form fields sent as the request body
parms = dict(name1='value1', name2='value2')
# Extra headers
headers = {'User-agent': 'none/ofyourbusiness', 'Spam': 'Eggs'}

resp = requests.post(url, headers=headers, data=parms)
# Response body decoded to str
text = resp.text
创建TCP服务器
from socketserver import BaseRequestHandler, StreamRequestHandler, TCPServer
import socket
import traceback
class EchoHandler(StreamRequestHandler):
    """Echo every chunk received on a TCP connection back to the client.

    The class attributes below are applied by StreamRequestHandler.setup();
    on a plain BaseRequestHandler they would be silently ignored.
    """
    # Optional settings
    timeout = 5                      # Timeout on all socket operations
    rbufsize = -1                    # Read buffer size
    wbufsize = 0                     # Write buffer size
    disable_nagle_algorithm = False  # Sets TCP_NODELAY socket option

    def handle(self):
        print('Got connection from', self.client_address)
        try:
            while True:
                msg = self.request.recv(8192)
                if not msg:
                    break
                # send() may transmit only part of the buffer; sendall()
                # keeps sending until every byte has been written.
                self.request.sendall(msg)
        except socket.timeout:
            # The client stayed idle for longer than `timeout`; drop it quietly.
            pass
class EchoTCPServer(TCPServer):
    """TCPServer whose requests are handled by a fixed pool of threads.

    serve() runs several serve_forever() loops on the same listening socket
    at once: one per worker thread plus one in the calling thread.
    """
    # Allow an immediate restart on the same port (sets SO_REUSEADDR).
    allow_reuse_address = True

    def _serve_forever(self):
        """Run one serve_forever() loop, then ask the sibling loops to stop."""
        try:
            self.serve_forever()
        except KeyboardInterrupt:
            pass
        except Exception:
            traceback.print_exc()
        finally:
            # serve_forever() resets the server's shutdown request when it
            # returns, so other loops could miss it; requesting shutdown again
            # lets every remaining loop stop at its next poll.
            self.shutdown()

    def serve(self, nworkers=15):
        """Serve with *nworkers* extra threads until stopped, then close the socket.

        Blocks the calling thread, which runs one of the loops itself.
        """
        from concurrent.futures import ThreadPoolExecutor
        try:
            with ThreadPoolExecutor(max_workers=nworkers,
                                    thread_name_prefix='echo_tcp_server') as executor:
                for _ in range(nworkers):
                    executor.submit(self._serve_forever)
                self._serve_forever()
        finally:
            # Leaving the with-block waited for all worker loops, so nothing
            # uses the listening socket any more; release it.
            self.server_close()
if __name__ == '__main__':
    # Echo service on every interface, TCP port 20000, default worker count.
    EchoTCPServer(('', 20000), EchoHandler).serve()
创建UDP服务器
from socketserver import BaseRequestHandler, UDPServer
import time
import traceback
class TimeHandler(BaseRequestHandler):
    """Answer any incoming UDP datagram with the server's current local time."""

    def handle(self):
        print('Got connection from', self.client_address)
        # For datagram servers self.request is a (payload, socket) pair;
        # the payload itself is ignored.
        _payload, reply_sock = self.request
        now = time.ctime().encode('ascii')
        reply_sock.sendto(now, self.client_address)
class TimeUDPServer(UDPServer):
    """UDPServer whose datagrams are handled by a fixed pool of threads.

    serve() runs several serve_forever() loops on the same socket at once:
    one per worker thread plus one in the calling thread.
    """

    def _serve_forever(self):
        """Run one serve_forever() loop, then ask the sibling loops to stop."""
        try:
            self.serve_forever()
        except KeyboardInterrupt:
            pass
        except Exception:
            traceback.print_exc()
        finally:
            # serve_forever() resets the server's shutdown request when it
            # returns, so other loops could miss it; requesting shutdown again
            # lets every remaining loop stop at its next poll.
            self.shutdown()

    def serve(self, nworkers=15):
        """Serve with *nworkers* extra threads until stopped, then close the socket.

        Blocks the calling thread, which runs one of the loops itself.
        """
        from concurrent.futures import ThreadPoolExecutor
        try:
            with ThreadPoolExecutor(max_workers=nworkers,
                                    thread_name_prefix='time_udp_server') as executor:
                for _ in range(nworkers):
                    executor.submit(self._serve_forever)
                self._serve_forever()
        finally:
            # All worker loops have finished (the with-block joined them), so
            # the socket can be released safely.
            self.server_close()
if __name__ == '__main__':
    # Time service on every interface, UDP port 20000, default worker count.
    TimeUDPServer(('', 20000), TimeHandler).serve()
通过CIDR地址生成对应的IP地址集
>>> import ipaddress
>>> net = ipaddress.ip_network('123.45.67.64/27')
>>> net
IPv4Network('123.45.67.64/27')
>>> for a in net:
... print(a)
...
>>> net6 = ipaddress.ip_network('12:3456:78:90ab:cd:ef01:23:30/125')
>>> net6
IPv6Network('12:3456:78:90ab:cd:ef01:23:30/125')
>>> for a in net6:
... print(a)
...
>>> net.num_addresses
32
>>> net[0]
IPv4Address('123.45.67.64')
>>> net[1]
IPv4Address('123.45.67.65')
>>> net[-1]
IPv4Address('123.45.67.95')
>>> inet = ipaddress.ip_interface('123.45.67.73/27')
>>> inet.network
IPv4Network('123.45.67.64/27')
>>> inet.ip
IPv4Address('123.45.67.73')
>>>
通过XML-RPC实现简单的远程调用
from xmlrpc.server import SimpleXMLRPCServer
import traceback
class KeyValueServer:
    """Dictionary-backed key/value store published over XML-RPC.

    Every name listed in ``_rpc_methods_`` is looked up on the instance and
    registered with an internal SimpleXMLRPCServer, which is created with
    ``allow_none=True``.
    """
    _rpc_methods_ = ['get', 'set', 'delete', 'exists', 'keys']

    def __init__(self, address):
        self._data = {}
        self._serv = SimpleXMLRPCServer(address, allow_none=True)
        exported = [getattr(self, attr) for attr in self._rpc_methods_]
        for func in exported:
            self._serv.register_function(func)

    def get(self, name):
        """Return the value stored under *name* (KeyError if absent)."""
        return self._data[name]

    def set(self, name, value):
        """Store *value* under *name*, replacing any previous value."""
        self._data[name] = value

    def delete(self, name):
        """Remove *name* (KeyError if absent)."""
        del self._data[name]

    def exists(self, name):
        """Return True if *name* currently has a stored value."""
        return name in self._data

    def keys(self):
        """Return all stored names as a list."""
        return [*self._data]

    def serve_forever(self):
        """Serve XML-RPC requests until interrupted, then request server shutdown."""
        try:
            self._serv.serve_forever()
        except KeyboardInterrupt:
            pass
        except Exception:
            traceback.print_exc()
        finally:
            self._serv.shutdown()
# Example
if __name__ == '__main__':
    from concurrent.futures import ThreadPoolExecutor
    nworkers = 15
    serv = KeyValueServer(('', 15000))
    # Each worker thread runs its own serve loop on the same server; the main
    # thread runs one more and blocks there.
    with ThreadPoolExecutor(max_workers=nworkers,
                            thread_name_prefix='keyvalue_rcp_server') as executor:
        for _ in range(nworkers):
            executor.submit(serv.serve_forever)
        serv.serve_forever()
>>> from xmlrpc.client import ServerProxy
>>> s = ServerProxy('http://localhost:15000', allow_none=True)
>>> s.set('foo', 'bar')
>>> s.set('spam', [1, 2, 3])
在不同的Python解释器之间交互
from multiprocessing.connection import Listener
import traceback
def echo_client(conn):
    """Echo every object received on *conn* back to the sender until it goes away.

    The connection is always closed on return so it does not leak. An abrupt
    disconnect (ConnectionError) is treated like a normal EOF instead of
    propagating into the caller's accept loop.
    """
    with conn:
        try:
            while True:
                msg = conn.recv()
                conn.send(msg)
        except (EOFError, ConnectionError):
            print('Connection closed')
def serve_forever(serv):
    """Accept clients on the multiprocessing Listener *serv* and echo for each in turn.

    Several threads may run this loop on the same Listener. A client that
    fails the authkey handshake, or hangs up during it, is reported and
    skipped instead of ending the loop and closing the shared listener.
    The listener is closed when the loop ends.
    """
    from multiprocessing import AuthenticationError
    try:
        while True:
            try:
                client = serv.accept()
            except (AuthenticationError, EOFError, ConnectionResetError):
                # Per-client handshake failure: keep serving everyone else.
                traceback.print_exc()
                continue
            echo_client(client)
    except (KeyboardInterrupt, ConnectionAbortedError):
        pass
    except Exception:
        traceback.print_exc()
    finally:
        serv.close()
if __name__ == '__main__':
    from concurrent.futures import ThreadPoolExecutor
    # Clients must connect with the same authkey (see the Client example).
    serv = Listener(('', 25000), authkey=b'peekaboo')
    nworkers = 15
    with ThreadPoolExecutor(max_workers=nworkers,
                            thread_name_prefix='multiprocessing_listener') as executor:
        # Every worker and the main thread share one Listener, each running
        # its own accept loop.
        for _ in range(nworkers):
            executor.submit(serve_forever, serv)
        serve_forever(serv)
>>> from multiprocessing.connection import Client
>>> c = Client(('localhost', 25000), authkey=b'peekaboo')
>>> c.send('hello')
>>> c.recv()
'hello'
>>> c.send(42)
>>> c.recv()
42
>>> c.send([1, 2, 3, 4, 5])
>>> c.recv()
[1, 2, 3, 4, 5]
简单的客户端认证
import hmac
import os
def _recv_exactly(connection, nbytes):
    """Read exactly *nbytes* from a stream socket, or fewer if the peer closes first."""
    chunks = []
    remaining = nbytes
    while remaining:
        chunk = connection.recv(remaining)
        if not chunk:
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)

def client_authenticate(connection, secret_key, *, digestmod='sha256'):
    '''
    Authenticate client to a remote service.
    connection represents a network connection.
    secret_key is a key known only to both client/server.
    digestmod must match the value used by server_authenticate().
    '''
    # TCP may deliver the 32-byte challenge in pieces, so read until complete.
    message = _recv_exactly(connection, 32)
    # hmac.new() requires an explicit digestmod since Python 3.8.
    digest = hmac.new(secret_key, message, digestmod).digest()
    connection.sendall(digest)

def server_authenticate(connection, secret_key, *, digestmod='sha256'):
    '''
    Request client authentication.
    Sends a random challenge and returns True only if the client answers
    with the matching HMAC (compared in constant time).
    '''
    message = os.urandom(32)
    connection.sendall(message)
    digest = hmac.new(secret_key, message, digestmod).digest()
    response = _recv_exactly(connection, len(digest))
    return hmac.compare_digest(digest, response)
from socket import socket, AF_INET, SOCK_STREAM
secret_key = b'peekaboo'
def echo_handler(client_sock):
    """Authenticate a freshly accepted client, then echo its data back until EOF.

    The socket is closed on every exit path: after a failed authentication
    and after the client finishes normally.
    """
    with client_sock:
        if not server_authenticate(client_sock, secret_key):
            return
        while True:
            msg = client_sock.recv(8192)
            if not msg:
                break
            client_sock.sendall(msg)
from socket import socket, AF_INET, SOCK_STREAM
secret_key = b'peekaboo'
# The with-block closes the socket once the exchange is done.
with socket(AF_INET, SOCK_STREAM) as s:
    s.connect(('localhost', 18000))
    client_authenticate(s, secret_key)
    # sendall() guarantees the whole payload is written.
    s.sendall(b'Hello World')
    resp = s.recv(1024)
在网络服务中加入SSL
import ssl
class SSLMixin:
    '''
    Mixin class that adds support for SSL to existing servers based
    on the socketserver module.

    Uses an ssl.SSLContext, because the module-level ssl.wrap_socket() that
    this recipe originally used was removed in Python 3.12.
    '''
    def __init__(self, *args,
                 keyfile=None, certfile=None, ca_certs=None,
                 cert_reqs=ssl.CERT_NONE,
                 **kwargs):
        self._keyfile = keyfile
        self._certfile = certfile
        self._ca_certs = ca_certs
        self._cert_reqs = cert_reqs
        # Build the context once, before the base class binds the listening
        # socket, so a bad key/cert path fails fast without leaking a socket.
        self._ssl_context = self._make_ssl_context()
        super().__init__(*args, **kwargs)

    def _make_ssl_context(self):
        """Create the server-side SSLContext from the stored key/cert settings."""
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        if self._certfile is not None:
            context.load_cert_chain(self._certfile, self._keyfile)
        if self._ca_certs is not None:
            context.load_verify_locations(self._ca_certs)
        context.verify_mode = self._cert_reqs
        return context

    def get_request(self):
        client, addr = super().get_request()
        try:
            client_ssl = self._ssl_context.wrap_socket(client, server_side=True)
        except OSError:
            # Handshake failed (ssl.SSLError is an OSError): close the raw
            # socket instead of leaking it, and let socketserver skip the request.
            client.close()
            raise
        return client_ssl, addr
class SSLSimpleXMLRPCServer(SSLMixin, SimpleXMLRPCServer):
    """SimpleXMLRPCServer whose accepted connections are wrapped by SSLMixin."""
if __name__ == '__main__':
    KEYFILE='server_key.pem'    # Private key of the server
    CERTFILE='server_cert.pem'  # Server certificate

    class SSLKeyValueServer(KeyValueServer):
        """KeyValueServer whose XML-RPC endpoint is served over TLS.

        KeyValueServer.__init__ takes only an address and always builds a plain
        SimpleXMLRPCServer (so passing keyfile/certfile to it raises TypeError);
        this variant routes the TLS options to SSLSimpleXMLRPCServer instead.
        """
        def __init__(self, address, **ssl_options):
            self._data = {}
            self._serv = SSLSimpleXMLRPCServer(address, allow_none=True, **ssl_options)
            for name in self._rpc_methods_:
                self._serv.register_function(getattr(self, name))

    kvserv = SSLKeyValueServer(('', 15000),
                               keyfile=KEYFILE,
                               certfile=CERTFILE)
    kvserv.serve_forever()
理解事件驱动的IO
class EventHandler:
    """Base class for objects driven by event_loop().

    Subclasses must provide fileno() and override whichever receive/send
    hooks they need; by default a handler asks for neither direction.
    """
    def fileno(self):
        'Return the associated file descriptor'
        # NotImplemented is a comparison sentinel, not an exception class;
        # "raise NotImplemented(...)" actually raised TypeError.
        raise NotImplementedError('must implement')

    def wants_to_receive(self):
        'Return True if receiving is allowed'
        return False

    def handle_receive(self):
        'Perform the receive operation'
        pass

    def wants_to_send(self):
        'Return True if sending is requested'
        return False

    def handle_send(self):
        'Send outgoing data'
        pass
import select
def event_loop(handlers):
    """Dispatch readiness events to *handlers* forever using select().

    *handlers* is the live list of handler objects; handlers may add or remove
    entries while running (e.g. a server registering a new client, or a client
    unregistering itself on close).
    """
    while True:
        wants_recv = [h for h in handlers if h.wants_to_receive()]
        wants_send = [h for h in handlers if h.wants_to_send()]
        can_recv, can_send, _ = select.select(wants_recv, wants_send, [])
        for h in can_recv:
            # Skip handlers that an earlier callback in this pass unregistered.
            if h in handlers:
                h.handle_receive()
        for h in can_send:
            # A receive handler above may have closed and removed this handler
            # (e.g. the peer hung up with data still queued); sending on its
            # closed socket would raise and stop the loop.
            if h in handlers:
                h.handle_send()
import socket
import time
class UDPServer(EventHandler):
    """Handler owning a UDP socket bound to *address*; always ready to receive."""

    def __init__(self, address):
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        sock.bind(address)
        self.sock = sock

    def fileno(self):
        return self.sock.fileno()

    def wants_to_receive(self):
        return True
class UDPTimeServer(UDPServer):
    """Reply to every datagram with the current local time; the payload is ignored."""

    def handle_receive(self):
        # Read with a buffer large enough for any datagram. On Windows,
        # recvfrom() with a buffer smaller than the datagram raises OSError
        # (WSAEMSGSIZE) instead of truncating, which would stop the event loop.
        msg, addr = self.sock.recvfrom(65535)
        self.sock.sendto(time.ctime().encode('ascii'), addr)
class UDPEchoServer(UDPServer):
    """Send each received datagram (read with an 8192-byte buffer) back to its sender."""

    def handle_receive(self):
        payload, sender = self.sock.recvfrom(8192)
        self.sock.sendto(payload, sender)
if __name__ == '__main__':
    # Time service on UDP port 14000 and echo service on UDP port 15000,
    # both driven by a single event loop.
    handlers = [
        UDPTimeServer(('', 14000)),
        UDPEchoServer(('', 15000)),
    ]
    event_loop(handlers)
class TCPServer(EventHandler):
    """Listening-socket handler that accepts TCP connections.

    For each accepted connection it calls ``client_handler(sock, handler_list)``
    and appends the resulting handler to *handler_list*.
    """

    def __init__(self, address, client_handler, handler_list):
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
        listener.bind(address)
        listener.listen(1)
        self.sock = listener
        self.client_handler, self.handler_list = client_handler, handler_list

    def fileno(self):
        return self.sock.fileno()

    def wants_to_receive(self):
        return True

    def handle_receive(self):
        conn, _peer = self.sock.accept()
        # Registering the new handler makes the event loop start polling it.
        self.handler_list.append(self.client_handler(conn, self.handler_list))
class TCPClient(EventHandler):
    """Base handler for one accepted TCP connection with a buffered outgoing queue."""

    def __init__(self, sock, handler_list):
        self.sock, self.handler_list = sock, handler_list
        self.outgoing = bytearray()

    def fileno(self):
        return self.sock.fileno()

    def close(self):
        self.sock.close()
        # Unregister so the event loop stops polling this closed socket.
        self.handler_list.remove(self)

    def wants_to_send(self):
        return bool(self.outgoing)

    def handle_send(self):
        sent = self.sock.send(self.outgoing)
        # Drop what the kernel accepted; any remainder goes out on a later pass.
        del self.outgoing[:sent]
class TCPEchoClient(TCPClient):
    """Queue every chunk received for sending back; close the connection on EOF."""

    def wants_to_receive(self):
        return True

    def handle_receive(self):
        chunk = self.sock.recv(8192)
        if not chunk:
            self.close()
            return
        self.outgoing.extend(chunk)
if __name__ == '__main__':
    handlers = []
    # The server gets the same list so it can register each accepted client.
    server = TCPServer(('', 16000), TCPEchoClient, handlers)
    handlers.append(server)
    event_loop(handlers)
对于阻塞或耗时的计算,可以把它们交给一个单独的线程池来处理,计算完成后再通知事件循环。
from concurrent.futures import ThreadPoolExecutor
import os
class ThreadPoolHandler(EventHandler):
    """Event-loop handler that runs functions in a thread pool and delivers results.

    Each finished job queues its (callback, result) pair and writes one byte
    to a socket pair. That makes ``done_sock`` readable, so the event loop
    calls handle_receive(), which runs the queued callbacks in the
    event-loop thread.
    """
    def __init__(self, nworkers):
        from collections import deque
        if os.name == 'posix':
            self.signal_done_sock, self.done_sock = socket.socketpair()
        else:
            # Emulate socketpair() with a loopback TCP connection.
            server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            server.bind(('127.0.0.1', 0))
            server.listen(1)
            self.signal_done_sock = socket.socket(socket.AF_INET,
                                                  socket.SOCK_STREAM)
            self.signal_done_sock.connect(server.getsockname())
            self.done_sock, _ = server.accept()
            server.close()
        # deque.append()/popleft() are thread-safe, so worker threads can
        # enqueue while the event loop dequeues. The old list was rebound to
        # [] after iteration, which could drop results appended in between.
        self.pending = deque()
        self.pool = ThreadPoolExecutor(nworkers)

    def fileno(self):
        return self.done_sock.fileno()

    # Callback that executes (in a worker thread) when the job is done
    def _complete(self, callback, r):
        # Enqueue before signalling, so every byte the event loop reads has
        # a matching queued entry.
        self.pending.append((callback, r.result()))
        self.signal_done_sock.send(b'x')

    # Run a function in a thread pool
    def run(self, func, args=(), kwargs=None, *, callback):
        if kwargs is None:
            kwargs = {}
        r = self.pool.submit(func, *args, **kwargs)
        r.add_done_callback(lambda r: self._complete(callback, r))

    def wants_to_receive(self):
        return True

    # Run callback functions of completed work
    def handle_receive(self):
        # Each signal byte stands for one completed job: consume the bytes
        # that are available and run exactly that many callbacks. Jobs that
        # finish later write their own byte and are handled on a later pass.
        signals = self.done_sock.recv(4096)
        for _ in range(len(signals)):
            callback, result = self.pending.popleft()
            callback(result)
# Deliberately naive, exponential-time Fibonacci: it stands in for a slow,
# CPU-bound job in the thread-pool example below.
def fib(n):
    """Return the n-th Fibonacci number, with every n < 2 mapping to 1."""
    return 1 if n < 2 else fib(n - 1) + fib(n - 2)
class UDPFibServer(UDPServer):
    """UDP server that computes fib(n) in a thread pool for each request.

    Each datagram should contain a decimal integer. The work is submitted to
    the module-level ``pool`` (created by the script's main block), and the
    result is sent back to the requester as ASCII text.
    """
    def handle_receive(self):
        msg, addr = self.sock.recvfrom(128)
        try:
            n = int(msg)
        except ValueError:
            # Malformed request: drop it rather than letting the exception
            # escape and stop the whole event loop.
            return
        pool.run(fib, (n,), callback=lambda r: self.respond(r, addr))

    def respond(self, result, addr):
        self.sock.sendto(str(result).encode('ascii'), addr)
if __name__ == '__main__':
    # `pool` must be a module-level name: UDPFibServer.handle_receive refers
    # to it as a global.
    pool = ThreadPoolHandler(16)
    fib_server = UDPFibServer(('', 16000))
    handlers = [pool, fib_server]
    event_loop(handlers)
文章作者 徐新杰
上次更新 2017-10-07
许可协议 © Copyright 2020 Jeremy Xu