More IPython Resources
In this chapter, we've just scratched the surface of using IPython to enable data science tasks. Much more information is
available both in print and on the Web, and here we'll list some other resources that you may find helpful.
Web Resources
The IPython website: The IPython website links to documentation, examples, tutorials, and a variety of other
resources.
The nbviewer website: This site shows static renderings of any IPython notebook available on the internet. The front
page features some example notebooks that you can browse to see what other folks are using IPython for!
A gallery of interesting Jupyter Notebooks: This ever-growing list of notebooks, powered by nbviewer, shows the
depth and breadth of numerical analysis you can do with IPython. It includes everything from short examples and
tutorials to full-blown courses and books composed in the notebook format!
Video Tutorials: searching the Internet, you will find many video-recorded tutorials on IPython. I'd especially
recommend seeking out tutorials from the PyCon, SciPy, and PyData conferences by Fernando Perez and Brian Granger,
two of the primary creators and maintainers of IPython and Jupyter.
Books
Python for Data Analysis: Wes McKinney's book includes a chapter that covers using IPython as a data scientist.
Although much of the material overlaps what we've discussed here, another perspective is always helpful.
Learning IPython for Interactive Computing and Data Visualization: This short book by Cyrille Rossant offers a good
introduction to using IPython for data analysis.
IPython Interactive Computing and Visualization Cookbook: Also by Cyrille Rossant, this book is a longer and more
advanced treatment of using IPython for data science. Despite its name, it's not just about IPython: it also goes into
some depth on a broad range of data science topics.
Finally, a reminder that you can find help on your own: IPython's ? -based help functionality (discussed in Help and
Documentation in IPython) can be very useful if you use it well and use it often. As you go through the examples here and
elsewhere, this can be used to familiarize yourself with all the tools that IPython has to offer.
IPython: Beyond Normal Python
There are many options for development environments for Python, and I'm often asked which one I use in my own work.
My answer sometimes surprises people: my preferred environment is IPython plus a text editor (in my case, Emacs or
Atom depending on my mood). IPython (short for Interactive Python) was started in 2001 by Fernando Perez as an
enhanced Python interpreter, and has since grown into a project aiming to provide, in Perez's words, "Tools for the entire
life cycle of research computing." If Python is the engine of our data science task, you might think of IPython as the
interactive control panel.
As well as being a useful interactive interface to Python, IPython also provides a number of useful syntactic additions to
the language; we'll cover the most useful of these additions here. In addition, IPython is closely tied with the Jupyter
project, which provides a browser-based notebook that is useful for development, collaboration, sharing, and even
publication of data science results. The IPython notebook is actually a special case of the broader Jupyter notebook
structure, which encompasses notebooks for Julia, R, and other programming languages. As an example of the
usefulness of the notebook format, look no further than the page you are reading: the entire manuscript for this book was
composed as a set of IPython notebooks.
IPython is about using Python effectively for interactive scientific and data-intensive computing. This chapter will start by
stepping through some of the IPython features that are useful to the practice of data science, focusing especially on the
syntax it offers beyond the standard features of Python. Next, we will go into a bit more depth on some of the more useful
"magic commands" that can speed up common tasks in creating and using data science code. Finally, we will touch on
some of the features of the notebook that make it useful in understanding data and sharing results.
Shell or Notebook?
There are two primary means of using IPython that we'll discuss in this chapter: the IPython shell and the IPython
notebook. The bulk of the material in this chapter is relevant to both, and the examples will switch between them
depending on what is most convenient. In the few sections that are relevant to just one or the other, we will explicitly state
that fact. Before we start, some words on how to launch the IPython shell and IPython notebook.
Launching the IPython Shell
This chapter, like most of this book, is not designed to be absorbed passively. I recommend that as you read through it,
you follow along and experiment with the tools and syntax we cover: the muscle-memory you build through doing this will
be far more useful than the simple act of reading about it. Start by launching the IPython interpreter by typing ipython
on the command-line; alternatively, if you've installed a distribution like Anaconda or EPD, there may be a launcher
specific to your system (we'll discuss this more fully in Help and Documentation in IPython).
Once you do this, you should see a prompt like the following:
IPython 4.0.1 -- An enhanced Interactive Python.
?         -> Introduction and overview of IPython's features.
%quickref -> Quick reference.
help      -> Python's own help system.
object?   -> Details about 'object', use 'object??' for extra details.

In [1]:
With that, you're ready to follow along.
Launching the Jupyter Notebook
The Jupyter notebook is a browser-based graphical interface to the IPython shell, and builds on it a rich set of dynamic
display capabilities. As well as executing Python/IPython statements, the notebook allows the user to include formatted
text, static and dynamic visualizations, mathematical equations, JavaScript widgets, and much more. Furthermore, these
documents can be saved in a way that lets other people open them and execute the code on their own systems.
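Notebook documents are stored as JSON files with the .ipynb extension, which is part of what makes them easy to share. As a rough sketch, assuming the nbformat 4 field names (real files written by Jupyter also carry kernel and language metadata, and the file name in the comment is hypothetical), a two-cell notebook can be built and read back with nothing but the json module:

```python
import json

# Minimal notebook structure, assuming nbformat 4.4 field names.
notebook = {
    "nbformat": 4,
    "nbformat_minor": 4,
    "metadata": {},
    "cells": [
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": "# Squares\nA short *formatted* note.",
        },
        {
            "cell_type": "code",
            "execution_count": None,
            "metadata": {},
            "outputs": [],
            "source": "[n ** 2 for n in range(4)]",
        },
    ],
}

# Serialize to the text that would be saved as, e.g., squares.ipynb (hypothetical name)
text = json.dumps(notebook, indent=1)

# Anyone can load it back with a plain JSON parser
loaded = json.loads(text)
print([cell["cell_type"] for cell in loaded["cells"]])  # ['markdown', 'code']
```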
Though the IPython notebook is viewed and edited through your web browser window, it must connect to a running
Python process in order to execute code. This process (known as a "kernel") is started and managed by the notebook
server, which you can launch by running the following command in your system shell:
$ jupyter notebook
This command will launch a local web server that will be visible to your browser. It immediately spits out a log showing
what it is doing; that log will look something like this:
$ jupyter notebook
[NotebookApp] Serving notebooks from local directory: /Users/jakevdp/PythonDataScienceHandbook
[NotebookApp] 0 active kernels
[NotebookApp] The IPython Notebook is running at: http://localhost:8888/
[NotebookApp] Use Control-C to stop this server and shut down all kernels (twice to skip confirmation).
Help and Documentation in IPython
If you read no other section in this chapter, read this one: I find the tools discussed here to be the most transformative
contributions of IPython to my daily workflow.
When a technologically-minded person is asked to help a friend, family member, or colleague with a computer problem,
most of the time it's less a matter of knowing the answer than of knowing how to quickly find an unknown answer. In
data science it's the same: searchable web resources such as online documentation, mailing-list threads, and
StackOverflow answers contain a wealth of information, even (especially?) if it is a topic you've found yourself searching for
before. Being an effective practitioner of data science is less about memorizing the tool or command you should use for
every possible situation, and more about learning to effectively find the information you don't know, whether through a
web search engine or another means.
One of the most useful functions of IPython/Jupyter is to shorten the gap between the user and the type of documentation
and search that will help them do their work effectively. While web searches still play a role in answering complicated
questions, an amazing amount of information can be found through IPython alone. Some examples of the questions
IPython can help answer in a few keystrokes:
How do I call this function? What arguments and options does it have?
What does the source code of this Python object look like?
What is in this package I imported? What attributes or methods does this object have?
Here we'll discuss IPython's tools to quickly access this information, namely the ? character to explore documentation,
the ?? characters to explore source code, and the Tab key for auto-completion.
Accessing Documentation with ?
The Python language and its data science ecosystem is built with the user in mind, and one big part of that is access to
documentation. Every Python object contains a reference to a string, known as a docstring, which in most cases will
contain a concise summary of the object and how to use it. Python has a built-in help() function that can access this
information and print the results. For example, to see the documentation of the built-in len function, you can do the
following:
In [1]: help(len)
Help on built-in function len in module builtins:

len(...)
    len(object) -> integer

    Return the number of items of a sequence or mapping.
Depending on your interpreter, this information may be displayed as inline text, or in some separate pop-up window.
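The docstring that help() prints is an ordinary attribute, __doc__, so it can also be retrieved programmatically. A minimal sketch using the standard inspect module (the exact wording of len's docstring varies between Python versions, so the output may differ from the help text shown above):

```python
import inspect

# The docstring is stored on the object itself, in the __doc__ attribute
print(len.__doc__)

# inspect.getdoc returns the same text with any indentation cleaned up,
# which is handy when you want the string itself rather than printed help
doc = inspect.getdoc(len)
print(doc)
```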
Because finding help on an object is so common and useful, IPython introduces the ? character as a shorthand for
accessing this documentation and other relevant information:
In [2]: len?
Type:        builtin_function_or_method
String form: <built-in function len>
Namespace:   Python builtin
Docstring:
len(object) -> integer

Return the number of items of a sequence or mapping.
This notation works for just about anything, including object methods:
In [3]: L = [1, 2, 3]
In [4]: L.insert?
Type:        builtin_function_or_method
String form: <built-in method insert of list object at 0x1024b8ea8>
Docstring:   L.insert(index, object) -- insert object before index
or even objects themselves, with the documentation from their type:
In [5]: L?
Type:        list
String form: [1, 2, 3]
Length:      3
Docstring:
list() -> new empty list
list(iterable) -> new list initialized from iterable's items
Importantly, this will even work for functions or other objects you create yourself! Here we'll define a small function with a
docstring:
In [6]: def square(a):
   ....:     """Return the square of a."""
   ....:     return a ** 2
   ....:
Note that to create a docstring for our function, we simply placed a string literal in the first line of the function body. Because
docstrings are usually multiple lines, by convention we used Python's triple-quote notation for multi-line strings.
Now we'll use the ? mark to find this doc string:
In [7]: square?
Type:        function
String form: <function square at 0x103713cb0>
Definition:  square(a)
Docstring:   Return the square of a.
This quick access to documentation via docstrings is one reason you should get in the habit of always adding such inline
documentation to the code you write!
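Most of what ? reports can be approximated with the standard inspect module. The describe() helper below is a hypothetical illustration, not IPython's implementation: it collects the type name, the call signature when one can be determined, and the cleaned-up docstring.

```python
import inspect


def square(a):
    """Return the square of a."""
    return a ** 2


def describe(obj):
    """Hypothetical helper: collect roughly what `obj?` displays."""
    info = {"Type": type(obj).__name__}
    try:
        # Non-callables raise TypeError; some builtins raise ValueError
        name = getattr(obj, "__name__", "")
        info["Definition"] = f"{name}{inspect.signature(obj)}"
    except (TypeError, ValueError):
        pass
    info["Docstring"] = inspect.getdoc(obj)
    return info


for key, value in describe(square).items():
    label = key + ":"
    print(f"{label:<12}{value}")
```

For a plain list such as [1, 2, 3], describe() omits the Definition entry, much as ? shows a Length field instead of a signature.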
Accessing Source Code with ??
Because the Python language is so easily readable, another level of insight can usually be gained by reading the source
code of the object you're curious about. IPython provides a shortcut to the source code with the double question mark
( ?? ):
In [8]: square??
Type:        function
String form: <function square at 0x103713cb0>
Definition:  square(a)
Source:
def square(a):
    """Return the square of a."""
    return a ** 2
For simple functions like this, the double question-mark can give quick insight into the under-the-hood details.
If you play with this much, you'll notice that sometimes the ?? suffix doesn't display any source code: this is generally
because the object in question is not implemented in Python, but in C or some other compiled extension language. If this
is the case, the ?? suffix gives the same output as the ? suffix. You'll find this particularly with many of Python's built-in
objects and types, for example len from above:
In [9]: len??
Type:        builtin_function_or_method
String form: <built-in function len>
Namespace:   Python builtin
Docstring:
len(object) -> integer

Return the number of items of a sequence or mapping.
Using ? and/or ?? gives a powerful and quick interface for finding information about what any Python function or
module does.
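The fallback behavior of ?? can be mimicked with inspect.getsource(), which raises TypeError for builtins implemented in C, such as len, and OSError when no source file can be found. The source_or_doc() helper is hypothetical and only a sketch of the idea:

```python
import inspect
import textwrap


def source_or_doc(obj):
    """Hypothetical helper: return source code if available, else the docstring."""
    try:
        # Works for functions and classes defined in Python source files
        return inspect.getsource(obj)
    except (TypeError, OSError):
        # C-implemented builtins raise TypeError; objects without a
        # retrievable source file raise OSError
        return inspect.getdoc(obj)


print(source_or_doc(textwrap.dedent).splitlines()[0])  # first line of Python source
print(source_or_doc(len))                              # falls back to the docstring
```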
Exploring Modules with Tab-Completion
IPython's other useful interface is the use of the Tab key for auto-completion and exploration of the contents of objects,
modules, and namespaces. In the examples that follow, we'll use <TAB> to indicate when the Tab key should be
pressed.
Tab-completion of object contents
Every Python object has various attributes and methods associated with it. Like with the help function discussed
before, Python has a built-in dir function that returns a list of these, but the tab-completion interface is much easier to
use in practice. To see a list of all available attributes of an object, you can type the name of the object followed by a
period (" . ") character and the Tab key:
In [10]: L.<TAB>
L.append   L.copy     L.extend   L.insert   L.remove   L.sort
L.clear    L.count    L.index    L.pop      L.reverse
To narrow down the list, you can type the first character or several characters of the name, and the Tab key will find the
matching attributes and methods:
In [10]: L.c<TAB>
L.clear  L.copy   L.count

In [10]: L.co<TAB>
L.copy   L.count
If there is only a single option, pressing the Tab key will complete the line for you. For example, the following will instantly
be replaced with L.count :
In [10]: L.cou<TAB>
Though Python has no strictly enforced distinction between public/external attributes and private/internal attributes, by
convention a preceding underscore is used to denote private attributes and methods. For clarity, these private methods and special
methods are omitted from the list by default, but it's possible to list them by explicitly typing the underscore:
In [10]: L._<TAB>
L.__add__           L.__gt__            L.__reduce__
L.__class__         L.__hash__          L.__reduce_ex__
For brevity, we've only shown the first couple of lines of the output. Most of these are Python's special double-underscore
methods (often nicknamed "dunder" methods).
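Under the hood, attribute completion comes down to filtering the names that dir() returns. The complete() function below is a hypothetical sketch of that idea, not IPython's completer (which handles many more cases):

```python
def complete(obj, prefix):
    """Hypothetical sketch: attribute names of obj that start with prefix.

    Underscore-prefixed names are hidden unless the prefix itself starts
    with an underscore, mirroring the default behavior described above.
    """
    names = [name for name in dir(obj) if name.startswith(prefix)]
    if not prefix.startswith("_"):
        names = [name for name in names if not name.startswith("_")]
    return names


L = [1, 2, 3]
print(complete(L, ""))     # all public list methods
print(complete(L, "co"))   # ['copy', 'count']
print(complete(L, "cou"))  # a single match: ['count']
print(complete(L, "__")[:4])
```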
Tab completion when importing
Tab completion is also useful when importing objects from packages. Here we'll use it to find all possible imports in the
itertools package that start with co :
In [10]: from itertools import co<TAB>
combinations                   compress
combinations_with_replacement  count
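The candidates for this kind of import completion are the module's attribute names, so a comparable list can be produced with dir() and a prefix filter (a sketch of the idea, not how IPython's completer is implemented):

```python
import itertools

# Names in the itertools module that start with "co",
# roughly the candidates offered for `from itertools import co<TAB>`
matches = [name for name in dir(itertools) if name.startswith("co")]
print(matches)  # ['combinations', 'combinations_with_replacement', 'compress', 'count']
```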
Similarly, you can use tab-completion to see which imports are available on your system (this will change depending on
which third-party scripts and modules are visible to your Python session):
In [10]: import <TAB>
Display all 399 possibilities? (y or n)
Crypto              dis                 py_compile
Cython              distutils           pyclbr
...                 ...                 ...
difflib             pwd                 zmq
In [10]: import h<TAB>
hashlib             hmac                http
heapq               html                husl
(Note that for brevity, I did not print here all 399 importable packages and modules on my system.)
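A similar listing of importable top-level modules can be assembled from the standard library. This sketch assumes that sys.builtin_module_names plus pkgutil.iter_modules() is a reasonable stand-in for what IPython's import completer scans; the importable() helper is hypothetical, and its output depends on your installation:

```python
import pkgutil
import sys


def importable(prefix=""):
    """Hypothetical helper: top-level module names visible to this interpreter."""
    # Modules compiled directly into the interpreter
    names = set(sys.builtin_module_names)
    # Modules and packages found on sys.path (standard library and third-party)
    names.update(info.name for info in pkgutil.iter_modules())
    return sorted(name for name in names if name.startswith(prefix))


print(importable("h"))
```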
Beyond tab completion: wildcard matching
Tab completion is useful if you know the first few characters of the object or attribute you're looking for, but is little help if
you'd like to match characters at the middle or end of the word. For this use case, IPython provides a means of wildcard
matching for names using the * character.
For example, we can use this to list every object in the namespace that ends with Warning :
In [10]: *Warning?
BytesWarning
DeprecationWarning
FutureWarning
ImportWarning
PendingDeprecationWarning
ResourceWarning
RuntimeWarning
SyntaxWarning
UnicodeWarning
UserWarning
Warning
Notice that the * character matches any string, including the empty string.
Similarly, suppose we are looking for a string method that contains the word find somewhere in its name. We can
search for it this way:
In [10]: str.*find*?
str.find
str.rfind
I find that this type of flexible wildcard search can be very useful for finding a particular command when getting to know a new
package or reacquainting myself with a familiar one.
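The standard library's fnmatch module follows the same shell-style convention, where * matches any run of characters including the empty string, so the two searches above can be approximated in plain Python as follows (an analogy, not IPython's own matching code):

```python
import builtins
import fnmatch

# The equivalent of `*Warning?`: every builtin name ending in "Warning"
warnings = fnmatch.filter(dir(builtins), "*Warning")
print(warnings)

# The equivalent of `str.*find*?`
print(fnmatch.filter(dir(str), "*find*"))  # ['find', 'rfind']
```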
Keyboard Shortcuts in the IPython Shell
If you spend any amount of time on the computer, you've probably found a use for keyboard shortcuts in your workflow.
Most familiar perhaps are the Cmd-C and Cmd-V (or Ctrl-C and Ctrl-V) for copying and pasting in a wide variety of
programs and systems. Power-users tend to go even further: popular text editors like Emacs, Vim, and others provide
users an incredible range of operations through intricate combinations of keystrokes.
The IPython shell doesn't go this far, but does provide a number of keyboard shortcuts for fast navigation while typing
commands. These shortcuts are not in fact provided by IPython itself, but through its dependency on the GNU Readline
library: as such, some of the following shortcuts may differ depending on your system configuration. Also, while some of
these shortcuts do work in the browser-based notebook, this section is primarily about shortcuts in the IPython shell.
Once you get accustomed to these, they can be very useful for quickly performing certain commands without moving
your hands from the "home" keyboard position. If you're an Emacs user or if you have experience with Linux-style shells,
the following will be very familiar. We'll group these shortcuts into a few categories: navigation shortcuts, text entry
shortcuts, command history shortcuts, and miscellaneous shortcuts.
Navigation shortcuts
While the use of the left and right arrow keys to move backward and forward in the line is quite obvious, there are other
options that don't require moving your hands from the "home" keyboard position:
Keystroke                        Action
Ctrl-a                           Move cursor to the beginning of the line
Ctrl-e                           Move cursor to the end of the line
Ctrl-b or the left arrow key     Move cursor back one character
Ctrl-f or the right arrow key    Move cursor forward one character
Translator's note: if you are familiar with Bash, these four shortcuts will already be familiar.
Text Entry Shortcuts
While everyone is familiar with using the Backspace key to delete the previous character, reaching for the key often
requires some minor finger gymnastics, and it only deletes a single character at a time. In IPython there are several
shortcuts for removing some portion of the text you're typing. The most immediately useful of these are the commands to
delete entire lines of text. You'll know these have become second nature if you find yourself using a combination of Ctrl-b
and Ctrl-d instead of reaching for Backspace to delete the previous character!
Keystroke        Action
Backspace key    Delete previous character in line
Ctrl-d           Delete next character in line
Ctrl-k           Cut text from cursor to end of line
Ctrl-u           Cut text from beginning of line to cursor
Ctrl-y           Yank (i.e., paste) text that was previously cut
Ctrl-t           Transpose (i.e., switch) previous two characters
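To make the cut-and-yank behavior in this table concrete, here is a toy model of a single line being edited. The LineBuffer class and its method names are hypothetical; it illustrates the behavior described above, not how GNU Readline implements it:

```python
class LineBuffer:
    """Toy model of a few editing keys; not how GNU Readline implements them."""

    def __init__(self, text=""):
        self.text = text
        self.cursor = len(text)  # cursor starts at the end of the line
        self.killed = ""         # most recently cut text, available for yanking

    def insert(self, chars):
        """Type characters at the cursor."""
        self.text = self.text[:self.cursor] + chars + self.text[self.cursor:]
        self.cursor += len(chars)

    def ctrl_a(self):
        """Move the cursor to the beginning of the line."""
        self.cursor = 0

    def ctrl_e(self):
        """Move the cursor to the end of the line."""
        self.cursor = len(self.text)

    def ctrl_k(self):
        """Cut text from the cursor to the end of the line."""
        self.killed = self.text[self.cursor:]
        self.text = self.text[:self.cursor]

    def ctrl_u(self):
        """Cut text from the beginning of the line to the cursor."""
        self.killed = self.text[:self.cursor]
        self.text = self.text[self.cursor:]
        self.cursor = 0

    def ctrl_y(self):
        """Yank (paste) the most recently cut text at the cursor."""
        self.insert(self.killed)


buf = LineBuffer("square(2)")
buf.ctrl_u()             # cut the whole line, since the cursor is at the end
buf.insert("result = ")  # type a new beginning
buf.ctrl_y()             # paste the cut text back after it
print(buf.text)          # result = square(2)
```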
Command History Shortcuts
Perhaps the most impactful shortcuts discussed here are the ones IPython provides for navigating the command history.
This command history goes beyond your current IPython session: your entire command history is stored in a SQLite
database in your IPython profile directory. The most straightforward way to access these is with the up and down arrow
keys to step through the history, but other options exist as well:
Keystroke                        Action
Ctrl-p (or the up arrow key)     Access previous command in history
Ctrl-n (or the down arrow key)   Access next command in history
Ctrl-r                           Reverse-search through command history
The reverse-search can be particularly useful. Recall that in the previous section we defined a function called square .
Let's reverse-search our Python history from a new IPython shell and find this definition again. When you press Ctrl-r in
the IPython terminal, you'll see the following prompt:
In [1]:
(reverse-i-search)`':
If you start typing characters at this prompt, IPython will auto-fill the most recent command, if any, that matches those
characters:
In [1]:
(reverse-i-search)`squa': square??
At any point, you can add more characters to refine the search, or press Ctrl-r again to search further for another
command that matches the query. If you followed along in the previous section, pressing Ctrl-r twice more gives:
In [1]:
(reverse-i-search)`squa': def square(a):
    """Return the square of a."""
    return a ** 2
Once you have found the command you're looking for, press Return and the search will end. We can then use the
retrieved command, and carry on with our session:
In [1]: def square(a):
   ...:     """Return the square of a."""
   ...:     return a ** 2
   ...:
In [2]: square(2)
Out[2]: 4
Note that Ctrl-p/Ctrl-n or the up/down arrow keys can also be used to search through history, but only by matching
characters at the beginning of the line. That is, if you type def and then press Ctrl-p, it would find the most recent
command (if any) in your history that begins with the characters def .
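Because the history lives in SQLite, a Ctrl-r style search can also be run directly against the database. This sketch assumes the layout IPython uses for its history (a history table with session, line, and source_raw columns) and, on a default install, a file at ~/.ipython/profile_default/history.sqlite; both are assumptions worth checking against your IPython version. To stay self-contained, the example builds a small stand-in database with made-up entries, and reverse_search() is a hypothetical helper:

```python
import os
import sqlite3
import tempfile
from contextlib import closing


def reverse_search(db_path, needle, limit=3):
    """Hypothetical helper: the most recent history entries containing `needle`.

    Assumes a `history` table with `session`, `line`, and `source_raw` columns.
    """
    with closing(sqlite3.connect(db_path)) as conn:
        rows = conn.execute(
            "SELECT source_raw FROM history "
            "WHERE instr(source_raw, ?) > 0 "
            "ORDER BY session DESC, line DESC LIMIT ?",
            (needle, limit),
        ).fetchall()
    return [source for (source,) in rows]


# Build a small stand-in database (made-up entries) so the example runs anywhere
demo_path = os.path.join(tempfile.mkdtemp(), "history.sqlite")
with closing(sqlite3.connect(demo_path)) as conn:
    conn.execute(
        "CREATE TABLE history (session INTEGER, line INTEGER, source TEXT, source_raw TEXT)"
    )
    entries = [(1, 6, "def square(a): ..."), (1, 7, "square?"), (1, 8, "square??"), (2, 1, "len?")]
    conn.executemany(
        "INSERT INTO history VALUES (?, ?, ?, ?)",
        [(session, line, src, src) for session, line, src in entries],
    )
    conn.commit()

print(reverse_search(demo_path, "squa"))  # ['square??', 'square?', 'def square(a): ...']
```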
Miscellaneous Shortcuts
Finally, there are a few miscellaneous shortcuts that don't fit into any of the preceding categories, but are nevertheless
useful to know:
Keystroke    Action
Ctrl-l       Clear terminal screen
Ctrl-c       Interrupt current Python command
Ctrl-d       Exit IPython session
The Ctrl-c in particular can be useful when you inadvertently start a very long-running job.
While some of the shortcuts discussed here may seem a bit tedious at first, they quickly become automatic with practice.
Once you develop that muscle memory, I suspect you will even find yourself wishing they were available in other
contexts.
IPython Magic Commands
The previous two sections showed how IPython lets you use and explore Python efficiently and interactively. Here we'll
begin discussing some of the enhancements that IPython adds on top of the normal Python syntax. These are known in
IPython as magic commands, and are prefixed by the % character. These magic commands are designed to succinctly
solve various common problems in standard data analysis. Magic commands come in two flavors: line magics, which are
denoted by a single % prefix and operate on a single line of input, and cell magics, which are denoted by a double %%
prefix and operate on multiple lines of input. We'll demonstrate and discuss a few brief examples here, and come back to
more focused discussion of several useful magic commands later in the chapter.
Pasting Code Blocks: %paste and %cpaste
When working in the IPython interpreter, one common gotcha is that pasting multi-line code blocks can lead to
unexpected errors, especially when indentation and interpreter markers are involved. A common case is that you find
some example code on a website and want to paste it into your interpreter. Consider the following simple function:
>>> def donothing(x):
...     return x
The code is formatted as it would appear in the Python interpreter, and if you copy and paste this directly into IPython you
get an error:
In [2]: >>> def donothing(x):
   ...:     ...     return x
   ...:
  File "<ipython-input-20-5a66c8964687>", line 2
    ...     return x
                 ^
SyntaxError: invalid syntax
In the direct paste, the interpreter is confused by the additional prompt characters. But never fear: IPython's %paste
magic function is designed to handle this exact type of multi-line, marked-up input:
In [3]: %paste
>>> def donothing(x):
...     return x
## -- End pasted text --
The %paste command both enters and executes the code, so now the function is ready to be used:
In [4]: donothing(10)
Out[4]: 10
A command with a similar intent is %cpaste , which opens up an interactive multiline prompt in which you can paste
one or more chunks of code to be executed in a batch:
In [5]: %cpaste
Pasting code; enter '--' alone on the line to stop or use Ctrl-D.
:>>> def donothing(x):
:...     return x
:--
These magic commands, like others we'll see, make available functionality that would be difficult or impossible in a
standard Python interpreter.
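The following plain-Python sketch makes the idea behind %paste concrete by isolating its prompt-stripping step. The helper name strip_prompts is hypothetical, and this is not IPython's actual implementation, which also reads the clipboard and recognizes more prompt styles. It only removes leading >>> and ... markers and then executes what remains.

```python
import re
import textwrap

# Matches a leading ">>> " or "... " prompt (or a bare ">>>"/"..." at end of line).
_PROMPT = re.compile(r"^(>>>|\.\.\.)( |$)")

def strip_prompts(text):
    """Remove doctest-style prompts from each line of a pasted block (hypothetical helper)."""
    lines = textwrap.dedent(text).splitlines()
    return "\n".join(_PROMPT.sub("", line, count=1) for line in lines)

pasted = """\
>>> def donothing(x):
...     return x
"""

source = strip_prompts(pasted)   # "def donothing(x):\n    return x"
namespace = {}
exec(source, namespace)          # define donothing() in a fresh namespace
print(namespace["donothing"](10))  # prints 10
```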
Running External Code: %run
As you begin developing more extensive code, you will likely find yourself working in both IPython for interactive
exploration, as well as a text editor to store code that you want to reuse. Rather than running this code in a new window,
it can be convenient to run it within your IPython session. This can be done with the %run magic.
For example, imagine you've created a myscript.py file with the following contents:
#-------------------------------------
# file: myscript.py

def square(x):
    """square a number"""
    return x ** 2

for N in range(1, 4):
    print(N, "squared is", square(N))
You can execute this from your IPython session as follows:
In [6]: %run myscript.py
1 squared is 1
2 squared is 4
3 squared is 9
Note also that after you've run this script, any functions defined within it are available for use in your IPython session:
In [7]: square(5)
Out[7]: 25
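Outside IPython, the standard library's runpy.run_path() is a rough analogue of %run: it executes a file and returns the names the file defined. In this sketch a temporary file stands in for myscript.py. Unlike %run, run_path() does not inject anything into your namespace, so the function has to be pulled out of the returned dict explicitly.

```python
import os
import runpy
import tempfile

SCRIPT = '''\
def square(x):
    """square a number"""
    return x ** 2

for N in range(1, 4):
    print(N, "squared is", square(N))
'''

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "myscript.py")
    with open(path, "w") as f:
        f.write(SCRIPT)
    # Execute the file top to bottom (printing the three lines) with
    # __name__ set to "__main__", and get its resulting globals as a dict.
    namespace = runpy.run_path(path, run_name="__main__")

square = namespace["square"]
print(square(5))  # 25
```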
There are several options to fine-tune how your code is run; you can see the documentation in the normal way, by typing
%run? in the IPython interpreter.
Timing Code Execution: %timeit
Another example of a useful magic function is %timeit , which will automatically determine the execution time of the
single-line Python statement that follows it. For example, we may want to check the performance of a list comprehension:
In [8]: %timeit L = [n ** 2 for n in range(1000)]
1000 loops, best of 3: 325 µs per loop
The benefit of %timeit is that for short commands it will automatically perform multiple runs in order to attain more
robust results. For multiline statements, adding a second % sign will turn this into a cell magic that can handle multiple
lines of input. For example, here's the equivalent construction with a for loop:
In [9]: %%timeit
   ...: L = []
   ...: for n in range(1000):
   ...:     L.append(n ** 2)
   ...:
1000 loops, best of 3: 373 µs per loop
We can immediately see that, in this case, the list comprehension takes roughly 13% less time than the equivalent for loop
construction (325 µs versus 373 µs per loop). We'll explore %timeit and other approaches to timing and profiling code in Profiling and Timing Code.
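%timeit has a plain-Python counterpart in the standard library's timeit module. The sketch below runs the same comparison by wrapping each version in a function; the function names are illustrative. Absolute numbers will differ by machine and Python version, so only the two results relative to each other are meaningful.

```python
import timeit

def with_comprehension():
    return [n ** 2 for n in range(1000)]

def with_loop():
    L = []
    for n in range(1000):
        L.append(n ** 2)
    return L

# Each repeat times `loops` calls; keep the best repeat, as "best of 3" does.
loops = 200
best_comp = min(timeit.repeat(with_comprehension, number=loops, repeat=3)) / loops
best_loop = min(timeit.repeat(with_loop, number=loops, repeat=3)) / loops

print(f"list comprehension: {best_comp * 1e6:.0f} µs per loop")
print(f"for loop:           {best_loop * 1e6:.0f} µs per loop")
```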
Help on Magic Functions: ? , %magic , and %lsmagic
Like normal Python functions, IPython magic functions have docstrings, and this useful documentation can be accessed
in the standard manner. So, for example, to read the documentation of the %timeit magic simply type this:
In [10]: %timeit?
Documentation for other functions can be accessed similarly. To access a general description of available magic
functions, including some examples, you can type this:
In [11]: %magic
For a quick and simple list of all available magic functions, type this:
In [12]: %lsmagic
Finally, I'll mention that it is quite straightforward to define your own magic functions if you wish. We won't discuss it here,
but if you are interested, see the references listed in More IPython Resources.
Input and Output History
Previously we saw that the IPython shell allows you to access previous commands with the up and down arrow keys, or
equivalently the Ctrl-p/Ctrl-n shortcuts. Additionally, in both the shell and the notebook, IPython exposes several ways to
obtain the output of previous commands, as well as string versions of the commands themselves. We'll explore those
here.
IPython's In and Out Objects
By now I imagine you're quite familiar with the In [1]: / Out[1]: style prompts used by IPython. But it turns out that
these are not just pretty decoration: they give a clue as to how you can access previous inputs and outputs in your
current session. Imagine you start a session that looks like this:
In [1]: import math
In [2]: math.sin(2)
Out[2]: 0.9092974268256817
In [3]: math.cos(2)
Out[3]: -0.4161468365471424
We've imported the built-in math package, then computed the sine and the cosine of the number 2. These inputs and
outputs are displayed in the shell with In / Out labels, but there's more: IPython actually creates some Python
variables called In and Out that are automatically updated to reflect this history:
In [4]: print(In)
['', 'import math', 'math.sin(2)', 'math.cos(2)', 'print(In)']
In [5]: Out
Out[5]: {2: 0.9092974268256817, 3: -0.4161468365471424}
The In object is a list, which keeps track of the commands in order (the first item in the list is a place-holder so that
In[1] can refer to the first command):
In [6]: print(In[1])
import math
The Out object is not a list but a dictionary mapping input numbers to their outputs (if any):
In [7]: print(Out[2])
0.9092974268256817
Note that not all operations have outputs: for example, import statements and print calls don't affect the
output. The latter may be surprising, but makes sense if you consider that print is a function that returns None; for
brevity, any command that returns None is not added to Out.
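To make the bookkeeping concrete, here is a toy model of In and Out in plain Python. MiniHistory is a hypothetical class, not IPython's implementation. It only mimics the behavior described above: a placeholder at In[0], outputs keyed by input number, and None results left out of Out.

```python
import math

class MiniHistory:
    """Toy model of IPython's In list and Out dict (hypothetical; not IPython's code)."""

    def __init__(self):
        self.In = [""]      # index 0 is a placeholder, so In[1] is the first command
        self.Out = {}       # input number -> result, only for non-None results
        self.namespace = {}

    def run(self, source):
        number = len(self.In)
        self.In.append(source)
        try:
            result = eval(source, self.namespace)   # expressions produce a value
        except SyntaxError:
            exec(source, self.namespace)            # statements such as imports do not
            result = None
        if result is not None:
            self.Out[number] = result
        return result

h = MiniHistory()
h.run("import math")
h.run("math.sin(2)")
h.run("math.cos(2)")
h.run("print('printed, but not stored')")
print(h.In)
print(h.Out)   # only inputs 2 and 3 appear
print(math.isclose(h.Out[2] ** 2 + h.Out[3] ** 2, 1.0))  # True: reuse earlier results
```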
Where this can be useful is if you want to interact with past results. For example, let's check the sum of sin(2) ** 2
and cos(2) ** 2 using the previously computed results:
In [8]: Out[2] ** 2 + Out[3] ** 2
Out[8]: 1.0
The result is 1.0 as we'd expect from the well-known trigonometric identity. In this case, using these previous results
probably is not necessary, but it can become very handy if you execute a very expensive computation and want to reuse
the result!
Underscore Shortcuts and Previous Outputs
The standard Python shell contains just one simple shortcut for accessing previous output; the variable _ (i.e., a single
underscore) is kept updated with the previous output; this works in IPython as well:
In [9]: print(_)
1.0
But IPython takes this a bit further: you can use a double underscore to access the second-to-last output, and a triple
underscore to access the third-to-last output (skipping any commands with no output):
In [10]: print(__)
-0.4161468365471424
In [11]: print(___)
0.9092974268256817
IPython stops there: more than three underscores starts to get a bit hard to count, and at that point it's easier to refer to
the output by line number.
There is one more shortcut we should mention, however: a shorthand for Out[X] is _X (i.e., a single underscore
followed by the line number):
In [12]: Out[2]
Out[12]: 0.9092974268256817
In [13]: _2
Out[13]: 0.9092974268256817
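The shortcut rules can be written down as a small lookup function. underscore_lookup is hypothetical and works on a plain dict shaped like Out. IPython itself keeps _, __, ___, and _N updated as ordinary variables rather than computing them on demand.

```python
def underscore_lookup(out, name):
    """Resolve IPython-style output shortcuts against an Out-like dict.

    Hypothetical helper for illustration only:
      '_', '__', '___' -> last, second-to-last, third-to-last recorded output
      '_N'             -> out[N]
    Commands without output never enter `out`, so they are skipped automatically.
    """
    if name in ("_", "__", "___"):
        ordered = [out[k] for k in sorted(out)]
        return ordered[-len(name)]
    if name.startswith("_") and name[1:].isdigit():
        return out[int(name[1:])]
    raise KeyError(name)

# The outputs recorded in the session above, up to In [8]
Out = {2: 0.9092974268256817, 3: -0.4161468365471424, 8: 1.0}
for shortcut in ("_", "__", "___", "_2"):
    print(shortcut, "->", underscore_lookup(Out, shortcut))
```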
Suppressing Output
Sometimes you might wish to suppress the output of a statement (this is perhaps most common with the plotting
commands that we'll explore in Introduction to Matplotlib). Or maybe the command you're executing produces a result
that you'd prefer not to store in your output history, perhaps so that it can be deallocated when other references are
removed. The easiest way to suppress the output of a command is to add a semicolon to the end of the line:
In [14]: math.sin(2) + math.cos(2);
Note that the result is computed silently, and the output is neither displayed on the screen nor stored in the Out
dictionary:
In [15]: 14 in Out
Out[15]: False
Related Magic Commands
For accessing a batch of previous inputs at once, the %history magic command is very helpful. Here is how you can
print the first four inputs:
In [16]: %history -n 1-4
1: import math
2: math.sin(2)
3: math.cos(2)
4: print(In)
As usual, you can type %history? for more information and a description of options available. Other similar magic
commands are %rerun (which will re-execute some portion of the command history) and %save (which saves some
set of the command history to a file). For more information, I suggest exploring these using the ? help functionality
discussed in Help and Documentation in IPython.
IPython and Shell Commands
When working interactively with the standard Python interpreter, one of the frustrations is the need to switch between
multiple windows to access Python tools and system command-line tools. IPython bridges this gap, and gives you a
syntax for executing shell commands directly from within the IPython terminal. The magic happens with the exclamation
point: anything appearing after ! on a line will be executed not by the Python kernel, but by the system command-line.
The following assumes you're on a Unix-like system, such as Linux or Mac OSX. Some of the examples that follow will
fail on Windows, which uses a different type of shell by default (though with the 2016 announcement of native Bash shells
on Windows, soon this may no longer be an issue!). If you're unfamiliar with shell commands, I'd suggest reviewing the
Shell Tutorial put together by the always excellent Software Carpentry Foundation.
Quick Introduction to the Shell
A full intro to using the shell/terminal/command-line is well beyond the scope of this chapter, but for the uninitiated we will
offer a quick introduction here. The shell is a way to interact textually with your computer. Ever since the mid-1980s, when
Microsoft and Apple introduced the first versions of their now ubiquitous graphical operating systems, most computer
users have interacted with their operating system through familiar clicking of menus and drag-and-drop movements. But
operating systems existed long before these graphical user interfaces, and were primarily controlled through sequences
of text input: at the prompt, the user would type a command, and the computer would do what the user told it to. Those
early prompt systems are the precursors of the shells and terminals that most active data scientists still use today.
Someone unfamiliar with the shell might ask why you would bother with this, when many results can be accomplished by
simply clicking on icons and menus. A shell user might reply with another question: why hunt icons and click menus when
you can accomplish things much more easily by typing? While it might sound like a typical tech preference impasse,
when moving beyond basic tasks it quickly becomes clear that the shell offers much more control of advanced tasks,
though admittedly the learning curve can intimidate the average computer user.
As an example, here is a sample of a Linux/OSX shell session where a user explores, creates, and modifies directories
and files on their system ( osx:~ $ is the prompt, and everything after the $ sign is the typed command; text that is
preceded by a # is meant just as description, rather than something you would actually type in):
osx:~ $ echo "hello world"              # echo is like Python's print function
hello world

osx:~ $ pwd                             # pwd = print working directory
/home/jake                              # this is our current working directory

osx:~ $ ls                              # ls = list working directory contents
notebooks  projects

osx:~ $ cd projects/                    # cd = change directory

osx:projects $ pwd
/home/jake/projects

osx:projects $ ls
datasci_book   mpld3   myproject.txt

osx:projects $ mkdir myproject          # mkdir = make new directory

osx:projects $ cd myproject/

osx:myproject $ mv ../myproject.txt ./  # mv = move file. Here we move myproject.txt
                                        # from the parent directory (../) into the
                                        # current working directory (./)
osx:myproject $ ls
myproject.txt
Notice that all of this is just a compact way to do familiar operations (navigating a directory structure, creating a directory,
moving a file, etc.) by typing commands rather than clicking icons and menus. With just a few commands
( pwd , ls , cd , mkdir , mv , and cp ) you can do many of the most common file operations. It's when you go beyond
these basics that the shell approach becomes really powerful.
Shell Commands in IPython
Any command that works at the command-line can be used in IPython by prefixing it with the ! character. For example,
the ls , pwd , and echo commands can be run as follows:
In [1]: !ls
myproject.txt
In [2]: !pwd
/home/jake/projects/myproject
In [3]: !echo "printing from the shell"
printing from the shell
Passing Values to and from the Shell
Shell commands can not only be called from IPython, but can also be made to interact with the IPython namespace. For
example, you can save the output of any shell command to a Python list using the assignment operator:
In [4]: contents = !ls
In [5]: print(contents)
['myproject.txt']
In [6]: directory = !pwd
In [7]: print(directory)
['/Users/jakevdp/notebooks/tmp/myproject']
Note that these results are not returned as plain Python lists, but as a special shell return type defined in IPython:
In [8]: type(directory)
IPython.utils.text.SList
This looks and acts a lot like a Python list, but has additional functionality, such as the grep and fields methods
and the s , n , and p properties that allow you to search, filter, and display the results in convenient ways. For more
information on these, you can use IPython's built-in help features.
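Outside IPython, the closest standard-library equivalent of `contents = !ls` is to run the command with subprocess and split its output into lines. This sketch assumes a POSIX shell. shell_lines is a hypothetical helper that returns a plain list, so the SList conveniences such as grep are replaced by an ordinary list comprehension.

```python
import subprocess

def shell_lines(command):
    """Run `command` in a subshell and return its standard output as a list of lines.

    Hypothetical helper: a rough stand-in for IPython's `contents = !command`,
    which returns an SList rather than a plain list.
    """
    result = subprocess.run(command, shell=True, capture_output=True,
                            text=True, check=True)
    return result.stdout.splitlines()

# printf stands in for a directory listing so the output is predictable
lines = shell_lines(r"printf 'alpha.txt\nbeta.py\ngamma.txt\n'")
print(lines)

# Plain-list filtering, loosely comparable to what SList.grep provides
txt_files = [line for line in lines if line.endswith(".txt")]
print(txt_files)
```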
Communication in the other direction, passing Python variables into the shell, is possible using the {varname} syntax:
In [9]: message = "hello from Python"
In [10]: !echo {message}
hello from Python
The curly braces contain the variable name, which is replaced by the variable's contents in the shell command.
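The reverse direction can be sketched the same way: build the command string from a Python variable, then hand it to the shell. This assumes a POSIX shell. Using shlex.quote() is this sketch's own precaution against spaces and special characters, not necessarily what IPython's {varname} expansion does.

```python
import shlex
import subprocess

message = "hello from Python"

# Build the command text from a Python variable, as !echo {message} does.
# shlex.quote() is an extra precaution in this sketch (an assumption, not
# necessarily IPython's behavior): it keeps spaces and quotes intact.
command = "echo " + shlex.quote(message)
output = subprocess.run(command, shell=True, capture_output=True, text=True).stdout
print(command)         # echo 'hello from Python'
print(output.strip())  # hello from Python
```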
Shell-Related Magic Commands
If you play with IPython's shell commands for a while, you might notice that you cannot use !cd to navigate the
filesystem:
In [11]: !pwd
/home/jake/projects/myproject
In [12]: !cd ..
In [13]: !pwd
/home/jake/projects/myproject
The reason is that each shell command run with ! is executed in a temporary subshell, so a directory change made there is
lost as soon as that command finishes. If you'd like to change the working directory in a more enduring way, you can use the %cd magic command:
In [14]: %cd ..
/home/jake/projects
In fact, by default you can even use this without the % sign:
In [15]: cd myproject
/home/jake/projects/myproject
This is known as an automagic function, and this behavior can be toggled with the %automagic magic function.
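The difference between !cd and %cd can be reproduced in plain Python. A command run through subprocess gets its own short-lived shell, while os.chdir() changes the working directory of the running process itself. This sketch assumes a POSIX shell and works inside a temporary directory, restoring the original working directory afterward.

```python
import os
import subprocess
import tempfile

start = os.getcwd()
with tempfile.TemporaryDirectory() as tmp:
    os.chdir(tmp)
    here = os.getcwd()

    # Like `!cd ..`: the cd runs in a child shell that exits immediately,
    # so this process's working directory is unchanged.
    subprocess.run("cd ..", shell=True, check=True)
    after_subshell = os.getcwd()

    # Like `%cd ..`: change the working directory of this process itself.
    os.chdir("..")
    after_chdir = os.getcwd()

    os.chdir(start)  # restore before the temporary directory is removed

print(after_subshell == here)                  # True
print(after_chdir == os.path.dirname(here))    # True
```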
Besides %cd , other available shell-like magic functions are %cat , %cp , %env , %ls , %man , %mkdir , %more ,
%mv , %pwd , %rm , and %rmdir , any of which can be used without the % sign if automagic is on. This makes it
so that you can almost treat the IPython prompt as if it's a normal shell:
In [16]: mkdir tmp
In [17]: ls
myproject.txt
tmp/
In [18]: cp myproject.txt tmp/
In [19]: ls tmp
myproject.txt
In [20]: rm -r tmp
This access to the shell from within the same terminal window as your Python session means that there is a lot less
switching back and forth between interpreter and shell as you write your Python code.
Errors and Debugging
Code development and data analysis always require a bit of trial and error, and IPython contains tools to streamline this
process. This section will briefly cover some options for controlling Python's exception reporting, followed by exploring
tools for debugging errors in code.
Controlling Exceptions: %xmode
Most of the time when a Python script fails, it will raise an Exception. When the interpreter hits one of these exceptions,
information about the cause of the error can be found in the traceback, which can be accessed from within Python. With
the %xmode magic function, IPython allows you to control the amount of information printed when the exception is
raised. Consider the following code:
In [1]: def func1(a, b):
            return a / b

        def func2(x):
            a = x
            b = x - 1
            return func1(a, b)
In [2]: func2(1)
---------------------------------------------------------------------------
ZeroDivisionError                         Traceback (most recent call last)
<ipython-input-2-7cb498ea7ed1> in <module>
----> 1 func2(1)

<ipython-input-1-586ccabd0db3> in func2(x)
      5     a = x
      6     b = x - 1
----> 7     return func1(a, b)

<ipython-input-1-586ccabd0db3> in func1(a, b)
      1 def func1(a, b):
----> 2     return a / b
      3
      4 def func2(x):
      5     a = x

ZeroDivisionError: division by zero
Calling func2 results in an error, and reading the printed trace lets us see exactly what happened. By default, this trace
includes several lines showing the context of each step that led to the error. Using the %xmode magic function (short for
Exception mode), we can change what information is printed.
%xmode takes a single argument, the mode, and there are three possibilities: Plain , Context , and Verbose .
The default is Context , and gives output like that just shown before. Plain is more compact and gives less
information:
In [3]: %xmode Plain
Exception reporting mode: Plain

In [4]: func2(1)
Traceback (most recent call last):

  File "<ipython-input-4-7cb498ea7ed1>", line 1, in <module>
    func2(1)

  File "<ipython-input-1-586ccabd0db3>", line 7, in func2
    return func1(a, b)

  File "<ipython-input-1-586ccabd0db3>", line 2, in func1
    return a / b

ZeroDivisionError: division by zero
The Verbose mode adds some extra information, including the arguments to any functions that are called:
In [5]: %xmode Verbose
Exception reporting mode: Verbose

In [6]: func2(1)
---------------------------------------------------------------------------
ZeroDivisionError                         Traceback (most recent call last)
<ipython-input-6-7cb498ea7ed1> in <module>
----> 1 func2(1)
        global func2 = <function func2 at 0x7fee38cf0d08>

<ipython-input-1-586ccabd0db3> in func2(x=1)
      5     a = x
      6     b = x - 1
----> 7     return func1(a, b)
        global func1 = <function func1 at 0x7fee38cf07b8>
        a = 1
        b = 0

<ipython-input-1-586ccabd0db3> in func1(a=1, b=0)
      1 def func1(a, b):
----> 2     return a / b
        a = 1
        b = 0
      3
      4 def func2(x):
      5     a = x

ZeroDivisionError: division by zero
This extra information can help narrow in on why the exception is being raised. So why not use the Verbose mode all
the time? As code gets complicated, this kind of traceback can get extremely long. Depending on the context, sometimes
the brevity of the default Context mode (or of Plain mode) is easier to work with.
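For scripts running outside IPython, the standard traceback module can produce comparable views. In this sketch, capture (a hypothetical helper) records the exception with capture_locals=True. The first loop prints a compact per-frame summary, loosely like Plain. The second prints each frame's local variables, loosely like Verbose. The output format is this sketch's own, not IPython's.

```python
import traceback

def func1(a, b):
    return a / b

def func2(x):
    a = x
    b = x - 1
    return func1(a, b)

def capture(call, *args):
    """Run call(*args) and return a TracebackException with local variables captured."""
    try:
        call(*args)
    except Exception as exc:
        return traceback.TracebackException.from_exception(exc, capture_locals=True)
    return None

te = capture(func2, 1)

# Roughly like %xmode Plain: one line per frame, then the error itself
for frame in te.stack:
    print(f"{frame.name}, line {frame.lineno}")
print("".join(te.format_exception_only()).strip())

# Roughly like %xmode Verbose: each frame's local variables (stored as reprs)
for frame in te.stack[1:]:      # skip the frame of capture() itself
    print(frame.name, frame.locals)
```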
Debugging: When Reading Tracebacks Is Not Enough
The standard Python tool for interactive debugging is pdb , the Python debugger. This debugger lets the user step
through the code line by line in order to see what might be causing a more difficult error. The IPython-enhanced version
of this is ipdb , the IPython debugger.
There are many ways to launch and use both these debuggers; we won't cover them fully here. Refer to the online
documentation of these two utilities to learn more.
In IPython, perhaps the most convenient interface to debugging is the %debug magic command. If you call it after
hitting an exception, it will automatically open an interactive debugging prompt at the point of the exception. The ipdb
prompt lets you explore the current state of the stack, explore the available variables, and even run Python commands!
Let's look at the most recent exception, then do some basic tasks: print the values of a and b , and type quit to quit
the debugging session:
In [7]: %debug
> <ipython-input-1-586ccabd0db3>(2)func1()
      1 def func1(a, b):
----> 2     return a / b
      3
      4 def func2(x):
      5     a = x

ipdb> print(a)
1
ipdb> print(b)
0
ipdb> quit
The interactive debugger allows much more than this, though: we can even step up and down through the stack and
explore the values of variables there:
In [8]: %debug
> <ipython-input-1-586ccabd0db3>(2)func1()
      1 def func1(a, b):
----> 2     return a / b
      3
      4 def func2(x):
      5     a = x

ipdb> up
> <ipython-input-1-586ccabd0db3>(7)func2()
      3
      4 def func2(x):
      5     a = x
      6     b = x - 1
----> 7     return func1(a, b)

ipdb> print(x)
1
ipdb> up
> <ipython-input-6-7cb498ea7ed1>(1)<module>()
----> 1 func2(1)

ipdb> down
> <ipython-input-1-586ccabd0db3>(7)func2()
      3
      4 def func2(x):
      5     a = x
      6     b = x - 1
----> 7     return func1(a, b)

ipdb> quit
This allows you to quickly find out not only what caused the error, but what function calls led up to the error.
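Outside IPython, the standard library's pdb can inspect the same failure after the fact; pdb.post_mortem() is the usual entry point. So that the example runs unattended, this sketch feeds the debugger a scripted list of commands through the stdin/stdout arguments of pdb.Pdb instead of typing them. That scripting setup is only an illustration device, and failing_traceback is a hypothetical helper.

```python
import io
import pdb

def func1(a, b):
    return a / b

def func2(x):
    a = x
    b = x - 1
    return func1(a, b)

def failing_traceback():
    """Return the traceback object of the ZeroDivisionError raised by func2(1)."""
    try:
        func2(1)
    except ZeroDivisionError as exc:
        return exc.__traceback__

# Commands that would normally be typed at the (Pdb) prompt
commands = io.StringIO("p a\np b\nup\np x\nquit\n")
transcript = io.StringIO()

debugger = pdb.Pdb(stdin=commands, stdout=transcript, readrc=False)
debugger.reset()
debugger.interaction(None, failing_traceback())   # the same steps pdb.post_mortem() takes

print(transcript.getvalue())
```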
If you'd like the debugger to launch automatically whenever an exception is raised, you can use the %pdb magic
function to turn on this automatic behavior:
In [9]: %xmode Plain
        %pdb on
        func2(1)
Exception reporting mode: Plain
Automatic pdb calling has been turned ON
Traceback (most recent call last):

  File "<ipython-input-9-f80f6b5cecf3>", line 3, in <module>
    func2(1)

  File "<ipython-input-1-586ccabd0db3>", line 7, in func2
    return func1(a, b)

  File "<ipython-input-1-586ccabd0db3>", line 2, in func1
    return a / b

ZeroDivisionError: division by zero

> <ipython-input-1-586ccabd0db3>(2)func1()
      1 def func1(a, b):
----> 2     return a / b
      3
      4 def func2(x):
      5     a = x

ipdb> print(b)
0
ipdb> quit
Finally, if you have a script that you'd like to run from the beginning in interactive mode, you can run it with the command
%run -d , and use the next command to step through the lines of code interactively.
Partial list of debugging commands
There are many more available commands for interactive debugging than we've listed here; the following table contains a
description of some of the more common and useful ones:
Command       Description
list          Show the current location in the file
h(elp)        Show a list of commands, or find help on a specific command
q(uit)        Quit the debugger and the program
c(ontinue)    Quit the debugger, continue in the program
n(ext)        Go to the next step of the program
<enter>       Repeat the previous command
p(rint)       Print variables
s(tep)        Step into a subroutine
r(eturn)      Return out of a subroutine
For more information, use the help command in the debugger, or take a look at ipdb's online documentation.
Profiling and Timing Code
In the process of developing code and creating data processing pipelines, there are often trade-offs you can make
between various implementations. Early in developing your algorithm, it can be counterproductive to worry about such
things. As Donald Knuth famously quipped, "We should forget about small efficiencies, say about 97% of the time:
premature optimization is the root of all evil."
But once you have your code working, it can be useful to dig into its efficiency a bit. Sometimes it's useful to check the
execution time of a given command or set of commands; other times it's useful to dig into a multiline process and
determine where the bottleneck lies in some complicated series of operations. IPython provides access to a wide array of
functionality for this kind of timing and profiling of code. Here we'll discuss the following IPython magic commands:
%time : Time the execution of a single statement
%timeit : Time repeated execution of a single statement for more accuracy
%prun : Run code with the profiler
%lprun : Run code with the line-by-line profiler
%memit : Measure the memory use of a single statement
%mprun : Run code with the line-by-line memory profiler
The last three commands are not bundled with IPython (%prun is built in): you'll need to get the line_profiler and
memory_profiler extensions, which we will discuss in the following sections.
Timing Code Snippets: %timeit and %time
We saw the %timeit line-magic and %%timeit cell-magic in the introduction to magic functions in IPython Magic
Commands; they can be used to time the repeated execution of snippets of code:
In [1]: %timeit sum(range(100))
737 ns ± 26.5 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
Note that because this operation is so fast, %timeit automatically does a large number of repetitions. For slower
commands, %timeit will automatically adjust and perform fewer repetitions:
In [2]: %%timeit
        total = 0
        for i in range(1000):
            for j in range(1000):
                total += i * (-1) ** j
238 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
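The standard library's timeit module exposes the same idea through Timer.autorange(), which keeps increasing the loop count until one batch runs for at least 0.2 seconds. IPython's %timeit uses its own heuristic, so the loop counts below will not necessarily match its output. This sketch only illustrates why fast statements get many more loops than slow ones; the slow loop is shortened from 1000 to 300 iterations per level to keep the example quick.

```python
import timeit

fast = timeit.Timer("sum(range(100))")
slow = timeit.Timer(
    "total = 0\n"
    "for i in range(300):\n"
    "    for j in range(300):\n"
    "        total += i * (-1) ** j"
)

# autorange() tries 1, 2, 5, 10, 20, 50, ... loops until one batch takes at
# least 0.2 seconds, then returns (number_of_loops, total_time_for_that_batch).
fast_loops, fast_time = fast.autorange()
slow_loops, slow_time = slow.autorange()

print(f"fast statement: {fast_loops} loops, {fast_time / fast_loops * 1e9:.0f} ns per loop")
print(f"slow statement: {slow_loops} loops, {slow_time / slow_loops * 1e3:.1f} ms per loop")
```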
Sometimes repeating an operation is not the best option. For example, if we have a list that we'd like to sort, we might be
misled by a repeated operation. Sorting a pre-sorted list is much faster than sorting an unsorted list, so the repetition will
skew the result:
In [3]: import random
        L = [random.random() for i in range(100000)]
        %timeit L.sort()
766 µs ± 182 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
For this, the %time magic function may be a better choice. It also is a good choice for longer-running commands, when
short, system-related delays are unlikely to affect the result. Let's time the sorting of an unsorted and a presorted list:
In [4]: import random
        L = [random.random() for i in range(100000)]
        print("sorting an unsorted list:")
        %time L.sort()
sorting an unsorted list:
CPU times: user 29.7 ms, sys: 9 µs, total: 29.7 ms
Wall time: 29.5 ms

In [5]: print("sorting an already sorted list:")
        %time L.sort()
sorting an already sorted list:
CPU times: user 4 ms, sys: 0 ns, total: 4 ms
Wall time: 4.01 ms
Notice how much faster the presorted list is to sort, but notice also how much longer the timing takes with %time versus
%timeit , even for the presorted list! This is a result of the fact that %timeit does some clever things under the hood
to prevent system calls from interfering with the timing. For example, it prevents cleanup of unused Python objects
(known as garbage collection) which might otherwise affect the timing. For this reason, %timeit results are usually
noticeably faster than %time results.
你应该首先注意到的是对于未排序的列表和对于已排序的列表进行排序的执行时间差别(译者注:在我的笔记本上,如上面的输出所示,相差约7倍)。而且你还需要了解 %time 和 %timeit 执行的区别,即使都是使用已经排好序的列表的情况下。这是因为 %timeit 会使用一种额外的机制来防止系统调用影响计时的结果。例如,它会阻止Python解释器清理不再使用的对象(也被称为垃圾收集),否则垃圾收集会影响计时的结果。因此,通常情况下 %timeit 的结果都会比 %time 的结果要快。
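The standard-library `timeit` module behaves the same way: by default it switches garbage collection off while timing, and its documentation suggests re-enabling it with `gc.enable()` in the setup string when you want the collector included. A minimal sketch comparing the two settings; the statement being timed is an arbitrary allocation-heavy example, and which variant comes out faster depends on the workload and the machine.

```python
import timeit

stmt = "[[] for _ in range(1000)]"  # allocates many small container objects

# Default behaviour: the collector is disabled for the duration of each timing run
t_default = min(timeit.repeat(stmt, repeat=5, number=1000))

# Re-enable the collector as the first statement of setup
t_with_gc = min(timeit.repeat(stmt, setup="import gc; gc.enable()", repeat=5, number=1000))

print(f"gc disabled (default): {t_default * 1e3:.2f} ms per 1000 runs")
print(f"gc enabled:            {t_with_gc * 1e3:.2f} ms per 1000 runs")
```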
For %time as with %timeit , using the double-percent-sign cell magic syntax allows timing of multiline scripts:
对于 %time 和 %timeit 指令,使⽤两个百分号可以对⼀段代码进⾏计时:
In [6]: %%time
total = 0
for i in range(1000):
    for j in range(1000):
        total += i * (-1) ** j
CPU times: user 334 ms, sys: 0 ns, total: 334 ms
Wall time: 333 ms
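%time reports two kinds of measurement: CPU time (user plus system) and wall-clock time. A rough standard-library analogue, shown only to illustrate the difference and not as a description of how the magic is implemented, is `time.process_time()` for CPU time and `time.perf_counter()` for elapsed time. The function below is a hypothetical example that does some computation and then sleeps; sleeping advances the wall clock without using the CPU.

```python
import time

def busy_then_sleep():
    total = sum(i * i for i in range(200_000))  # CPU-bound work
    time.sleep(0.2)                             # waiting: wall time passes, CPU is idle
    return total

wall_start = time.perf_counter()
cpu_start = time.process_time()
busy_then_sleep()
cpu = time.process_time() - cpu_start    # CPU time of this process (user + system)
wall = time.perf_counter() - wall_start  # elapsed wall-clock time

print(f"CPU time: {cpu * 1e3:.1f} ms, Wall time: {wall * 1e3:.1f} ms")
```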
For more information on %time and %timeit , as well as their available options, use the IPython help functionality
(i.e., type %time? at the IPython prompt).
更多关于 %time 和 %timeit 的资料,包括它们的选项,可以使⽤IPython的帮助功能(如在IPython提⽰符下键⼊ %time? )进⾏查
看。
Profiling Full Scripts: %prun
脚本代码块性能测算: %prun
A program is made of many single statements, and sometimes timing these statements in context is more important than
timing them on their own. Python contains a built-in code profiler (which you can read about in the Python
documentation), but IPython offers a much more convenient way to use this profiler, in the form of the magic function
%prun .
一个程序是由很多条语句组成的,有的时候在上下文中对这些语句进行计时比单独对每条语句计时更加重要。Python自带一个内建的代码性能测算工具(你可以在Python文档中找到它),而IPython通过 %prun 魔术指令提供了一个更加简便的方式来使用这个测算工具。
By way of example, we'll define a simple function that does some calculations:
我们定义⼀个简单的函数作为例⼦:
In [7]: def sum_of_lists(N):
    total = 0
    for i in range(5):
        L = [j ^ (j >> i) for j in range(N)]
        total += sum(L)
    return total
Now we can call %prun with a function call to see the profiled results:
然后我们就可以使⽤ %prun 来调⽤这个函数,并查看测算的结果:
In [8]: %prun sum_of_lists(1000000)
In the notebook, the output is printed to the pager, and looks something like this:
14 function calls in 0.714 seconds
Ordered by: internal time
ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     5    0.599    0.120    0.599    0.120 <ipython-input-19>:4(<listcomp>)
     5    0.064    0.013    0.064    0.013 {built-in method sum}
     1    0.036    0.036    0.699    0.699 <ipython-input-19>:1(sum_of_lists)
     1    0.014    0.014    0.714    0.714 <string>:1(<module>)
     1    0.000    0.000    0.714    0.714 {built-in method exec}
在译者的笔记本上,这个指令的结果输出如下:
14 function calls in 0.500 seconds
Ordered by: internal time
ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     5    0.440    0.088    0.440    0.088 <ipython-input-8-f105717832a2>:4(<listcomp>)
     5    0.027    0.005    0.027    0.005 {built-in method builtins.sum}
     1    0.025    0.025    0.492    0.492 <ipython-input-8-f105717832a2>:1(sum_of_lists)
     1    0.008    0.008    0.500    0.500 <string>:1(<module>)
     1    0.000    0.000    0.500    0.500 {built-in method builtins.exec}
     1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
The result is a table that indicates, in order of total time on each function call, where the execution is spending the most
time. In this case, the bulk of execution time is in the list comprehension inside sum_of_lists . From here, we could
start thinking about what changes we might make to improve the performance in the algorithm.
这个结果的表格,使⽤的是每个函数调⽤执⾏总时间进⾏排序(从⼤到⼩)。从上⾯的结果可以看出,绝⼤部分的执⾏时间都发⽣在函数
sum_of_lists 中的列表解析之上。然后,我们就可以知道如果需要优化这段代码的性能,可以从哪个⽅⾯开始着⼿了。
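The built-in profiler mentioned above can also be used directly from plain Python scripts, through the standard-library `cProfile` module, with `pstats` for sorting and printing the report. A minimal sketch; sorting by `"tottime"` produces the same "Ordered by: internal time" ordering as above, and the input size of 100,000 is arbitrary. On Python 3.12 and later, list comprehensions are inlined into the enclosing function, so their time is reported under `sum_of_lists` rather than as a separate `<listcomp>` row.

```python
import cProfile
import io
import pstats

def sum_of_lists(N):
    total = 0
    for i in range(5):
        L = [j ^ (j >> i) for j in range(N)]
        total += sum(L)
    return total

profiler = cProfile.Profile()
profiler.enable()                 # start collecting timing data
result = sum_of_lists(100_000)
profiler.disable()                # stop collecting

stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream)
stats.sort_stats("tottime").print_stats(5)  # five most expensive entries by internal time
report = stream.getvalue()
print(report)
```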
For more information on %prun , as well as its available options, use the IPython help functionality (i.e., type %prun?
at the IPython prompt).
更多关于 %prun 的资料,包括它的选项,可以使⽤IPython的帮助功能(在IPython提⽰符下键⼊ %prun? )进⾏查看。
Line-By-Line Profiling with %lprun
使⽤ %lprun 对单条代码执⾏性能进⾏测算
The function-by-function profiling of %prun is useful, but sometimes it's more convenient to have a line-by-line profile
report. This is not built into Python or IPython, but there is a line_profiler package available for installation that can
do this. Start by using Python's packaging tool, pip , to install the line_profiler package:
刚才介绍的对于整个函数进⾏测算的 %prun 很有⽤,但是有时能对单条代码进⾏性能测算会更加⽅便我们调优。这个功能不是内置在
Python或者IPython⾥的,你需要安装⼀个第三⽅包 line_profiler 来完成这项任务。使⽤Python包管理⼯具 pip 可以很容易地安装
line_profiler 包:
$ pip install line_profiler
Next, you can use IPython to load the line_profiler IPython extension, offered as part of this package:
然后,你可以使⽤IPython来载⼊ line_profiler 扩展模块:
In [9]: %load_ext line_profiler
Now the %lprun command will do a line-by-line profiling of any function. In this case, we need to tell it explicitly which
functions we're interested in profiling:
然后 %lprun 魔术指令就可以对任何函数进⾏单⾏的性能测算了,我们需要明确指出要对哪个函数进⾏性能测算:
In [10]: %lprun -f sum_of_lists sum_of_lists(5000)
As before, the notebook sends the result to the pager, but it looks something like this:
Timer unit: 1e-06 s
Total time: 0.009382 s
File: <ipython-input-19-fa2be176cc3e>
Function: sum_of_lists at line 1
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     1                                           def sum_of_lists(N):
     2         1            2      2.0      0.0      total = 0
     3         6            8      1.3      0.1      for i in range(5):
     4         5         9001   1800.2     95.9          L = [j ^ (j >> i) for j in range(N)]
     5         5          371     74.2      4.0          total += sum(L)
     6         1            0      0.0      0.0      return total
像刚才⼀样,notebook会在⼀个弹出⻚⾯中展⽰结果,在译者的笔记本上执⾏效果如下:
Timer unit: 1e-06 s
Total time: 0.007372 s
File: <ipython-input-7-f105717832a2>
Function: sum_of_lists at line 1
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     1                                           def sum_of_lists(N):
     2         1          2.0      2.0      0.0      total = 0
     3         6          9.0      1.5      0.1      for i in range(5):
     4         5       7114.0   1422.8     96.5          L = [j ^ (j >> i) for j in range(N)]
     5         5        246.0     49.2      3.3          total += sum(L)
     6         1          1.0      1.0      0.0      return total
The information at the top gives us the key to reading the results: the time is reported in microseconds and we can see
where the program is spending the most time. At this point, we may be able to use this information to modify aspects of
the script and make it perform better for our desired use case.
结果第⼀⾏给我们提供了下⾯表中的时间单位:微秒,我们可以从中看到函数中哪⼀⾏执⾏花了最多时间。然后,我们就可以根据这些信
息对我们的代码进⾏调优,以达到我们需要的性能指标。
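To make the idea behind a line-by-line profile concrete, here is a toy per-line timer built on the standard-library `sys.settrace` hook. It is not how `line_profiler` is implemented (that package ships its own compiled tracer), and the helper names `profile_lines` and `print_report` are hypothetical. Each trace event charges the time elapsed since the previous event to the line that was running, so time spent inside the list comprehension lands on line 4. Tracing adds substantial overhead, recursion is not handled, and hit counts for comprehension lines can differ between Python versions.

```python
import sys
import time
from collections import defaultdict


def profile_lines(func, *args, **kwargs):
    """Toy per-line timer for one non-recursive function (hypothetical helper)."""
    code = func.__code__
    hits = defaultdict(int)     # line number -> number of 'line' events
    times = defaultdict(float)  # line number -> seconds charged to that line
    state = {"line": None, "t": 0.0}

    def local_trace(frame, event, arg):
        now = time.perf_counter()
        # Charge the time since the previous event to the line that was running
        if state["line"] is not None:
            times[state["line"]] += now - state["t"]
        if event == "line":
            hits[frame.f_lineno] += 1
            state["line"] = frame.f_lineno
        elif event == "return":
            state["line"] = None
        state["t"] = now
        return local_trace

    def global_trace(frame, event, arg):
        # Trace only frames executing func's own code object
        return local_trace if frame.f_code is code else None

    old_trace = sys.gettrace()
    sys.settrace(global_trace)
    try:
        result = func(*args, **kwargs)
    finally:
        sys.settrace(old_trace)  # restore any previously installed tracer
    return result, hits, times


def print_report(func, hits, times):
    first = func.__code__.co_firstlineno
    total = sum(times.values()) or 1.0
    print(f"{'Line #':>6} {'Hits':>8} {'Time (us)':>12} {'% Time':>7}")
    for lineno in sorted(times):
        print(f"{lineno - first + 1:>6} {hits[lineno]:>8} "
              f"{times[lineno] * 1e6:>12.1f} {100 * times[lineno] / total:>7.1f}")


def sum_of_lists(N):
    total = 0
    for i in range(5):
        L = [j ^ (j >> i) for j in range(N)]
        total += sum(L)
    return total


result, hits, times = profile_lines(sum_of_lists, 5000)
print_report(sum_of_lists, hits, times)
```

Line numbers in the report are relative to the `def` line, matching the "Line #" column in the `%lprun` output above.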
For more information on %lprun , as well as its available options, use the IPython help functionality (i.e., type
%lprun? at the IPython prompt).
更多关于 %lprun 的资料,包括它的选项,可以使⽤IPython的帮助功能(在IPython提⽰符下键⼊ %lprun? )进⾏查看。
Profiling Memory Use: %memit and %mprun
测算内存使⽤: %memit 和 %mprun
Another aspect of profiling is the amount of memory an operation uses. This can be evaluated with another IPython
extension, the memory_profiler . As with the line_profiler , we start by pip -installing the extension:
对于性能测算来说,还有⼀个⽅⾯需要我们注意的是操作使⽤的内存⼤⼩。这需要⽤到另外⼀个IPython的扩展模块
memory_profiler 。就像 line_profiler 那样,我们可以使⽤ pip 安装这个扩展模块:
$ pip install memory_profiler
Then we can use IPython to load the extension:
然后将扩展模块加载到IPython中:
In [11]: %load_ext memory_profiler
The memory profiler extension contains two useful magic functions: the %memit magic (which offers a memory-measuring equivalent of %timeit ) and the %mprun function (which offers a memory-measuring equivalent of
%lprun ). The %memit function can be used rather simply:
内存性能测算⼯具 memory_profiler 包括两个有⽤的魔术指令: %memit (提供了与 %timeit 等同的内存测算功能)
和 %mprun (提供了与 %lprun 等同的内存测算功能)。 %memit 的⽤法⾮常简单:
In [12]: %memit sum_of_lists(1000000)
peak memory: 125.05 MiB, increment: 72.98 MiB
We see that while this function runs, memory use peaks at about 125 MiB, an increase of about 73 MiB.
我们可以看到,这个函数运行时内存使用的峰值约为125MiB,增加了约73MiB。
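If installing `memory_profiler` is not an option, the standard-library `tracemalloc` module gives a rough counterpart to `%memit`. Note that it measures something different: `tracemalloc` counts only memory allocated through Python's allocators, whereas `memory_profiler` reports the memory of the whole process, so the numbers will not match. A minimal sketch, using a smaller, arbitrary input to keep the traced run fast:

```python
import tracemalloc

def sum_of_lists(N):
    total = 0
    for i in range(5):
        L = [j ^ (j >> i) for j in range(N)]
        total += sum(L)
    return total

tracemalloc.start()                              # begin tracing Python allocations
result = sum_of_lists(100_000)
current, peak = tracemalloc.get_traced_memory()  # both values in bytes
tracemalloc.stop()                               # stop tracing and discard the traces

print(f"peak: {peak / 2**20:.1f} MiB, still allocated: {current / 2**20:.1f} MiB")
```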
For a line-by-line description of memory use, we can use the %mprun magic. Unfortunately, this magic works only for
functions defined in separate modules rather than the notebook itself, so we'll start by using the %%file magic to
create a simple module called mprun_demo.py , which contains our sum_of_lists function, with one addition that
will make our memory profiling results more clear:
对于单⾏代码的内存使⽤测算,我们可以使⽤ %mprun 魔术指令。不幸的是,这个魔术指令只能应⽤在独⽴模块⾥⾯的函数上,⽽不能应
⽤在notebook本⾝。因此我们需要使⽤ %%file 魔术指令来创建⼀个简单的模块,模块的名称为 mprun_demo.py ,该模块定义了前⾯
的 sum_of_lists 函数,在这个例⼦中,我们加了⼀⾏代码,来让我们的内存测算结果更加的明显:
In [13]: %%file mprun_demo.py
def sum_of_lists(N):
    total = 0
    for i in range(5):
        L = [j ^ (j >> i) for j in range(N)]
        total += sum(L)
        del L # remove reference to L
    return total
Overwriting mprun_demo.py
We can now import the new version of this function and run the memory line profiler:
下⾯我们可以载⼊这个模块,然后使⽤内存测算⼯具对改写后的函数进⾏单条代码的内存性能测算:
In [14]: from mprun_demo import sum_of_lists
%mprun -f sum_of_lists sum_of_lists(1000000)
The result, printed to the pager, gives us a summary of the memory use of the function, and looks something like this:
在弹出⻚⾯中展⽰的结果给我们⼤概描述了函数中每⾏代码内存的使⽤情况,在译者笔记本上结果如下:
Filename: ./mprun_demo.py
Line #    Mem usage    Increment   Line Contents
================================================
     4     71.9 MiB      0.0 MiB           L = [j ^ (j >> i) for j in range(N)]
Filename: ./mprun_demo.py
Line #    Mem usage    Increment   Line Contents
================================================
     1     39.0 MiB      0.0 MiB   def sum_of_lists(N):
     2     39.0 MiB      0.0 MiB       total = 0
     3     46.5 MiB      7.5 MiB       for i in range(5):
     4     71.9 MiB     25.4 MiB           L = [j ^ (j >> i) for j in range(N)]
     5     71.9 MiB      0.0 MiB           total += sum(L)
     6     46.5 MiB    -25.4 MiB           del L # remove reference to L
     7     39.1 MiB     -7.4 MiB       return total
Here the Increment column tells us how much each line affects the total memory budget: observe that when we
create and delete the list L , we are adding about 25 MB of memory usage. This is on top of the background memory
usage from the Python interpreter itself.
这里的 Increment 列告诉我们函数的每一行怎样影响到了总内存的使用量:观察一下当我们使用列表解析创建 L 和使用 del 删除 L 时发生的情况,这里会有大约25MB内存的使用变化。这是在Python解释器本身占用的基本内存基础上我们函数使用到的内存用量。
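The `del L` line also lowers the peak: without it, the previous list is still referenced while the next one is being built, so two lists are alive at once. A sketch measuring this with the standard-library `tracemalloc` module; the `release` flag and the `peak_bytes` helper are hypothetical additions made only for this comparison, and the absolute numbers will differ from `%mprun`, which measures the whole process.

```python
import tracemalloc

def sum_of_lists(N, release=True):
    total = 0
    for i in range(5):
        L = [j ^ (j >> i) for j in range(N)]
        total += sum(L)
        if release:
            del L  # drop the reference so this list is freed before the next is built
    return total

def peak_bytes(release):
    """Peak traced allocation (bytes) during one call of sum_of_lists."""
    tracemalloc.start()
    try:
        sum_of_lists(100_000, release=release)
        return tracemalloc.get_traced_memory()[1]
    finally:
        tracemalloc.stop()

with_del = peak_bytes(release=True)
without_del = peak_bytes(release=False)
print(f"peak with del:    {with_del / 2**20:.1f} MiB")
print(f"peak without del: {without_del / 2**20:.1f} MiB")
```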
For more information on %memit and %mprun , as well as their available options, use the IPython help functionality
(i.e., type %memit? at the IPython prompt).
更多关于 %memit 和 %mprun 的资料,包括它们的选项,可以使用IPython的帮助功能(在IPython提示符下键入 %memit? 或 %mprun? )进行查看。
More IPython Resources
更多IPython资源
In this chapter, we've just scratched the surface of using IPython to enable data science tasks. Much more information is
available both in print and on the Web, and here we'll list some other resources that you may find helpful.
本章中我们初步讨论了使⽤IPython来解决数据科学任务的⼀些基本内容。更多的内容可以在⽹上或书籍中找到,最后本⼩节来列出其中可
能对你有帮助的⼀些资源。
Web Resources
⽹络资源
The IPython website: The IPython website links to documentation, examples, tutorials, and a variety of other
resources.
The nbviewer website: This site shows static renderings of any IPython notebook available on the internet. The front
page features some example notebooks that you can browse to see what other folks are using IPython for!
A gallery of interesting Jupyter Notebooks: This ever-growing list of notebooks, powered by nbviewer, shows the
depth and breadth of numerical analysis you can do with IPython. It includes everything from short examples and
tutorials to full-blown courses and books composed in the notebook format!
Video Tutorials: searching the Internet, you will find many video-recorded tutorials on IPython. I'd especially
recommend seeking tutorials from the PyCon, SciPy, and PyData conferences by Fernando Perez and Brian Granger,
two of the primary creators and maintainers of IPython and Jupyter.
IPython官网: 在线文档、例子、教程和其他许多资源。
nbviewer网站: 能展示互联网上任何IPython notebook的静态渲染页面。首页展示了一些notebooks的例子,你可以看到其他人是怎样使用IPython的。
有趣的Jupyter notebooks展览馆: 这是一个不断增加的notebooks列表,基于nbviewer展示,体现了IPython在数值分析中应用的深度和广度。它应有尽有,从简短的例子,到稍长的教程,直至完整的课程和书籍,都是使用notebook格式。
视频教程:在互联网上可以搜索到很多关于IPython的视频教程。作者特别推荐PyCon、SciPy和PyData会议上Fernando Perez和Brian Granger做的报告,他们是IPython和Jupyter的主要创始人和维护者。
Books
书籍
Python for Data Analysis: Wes McKinney's book includes a chapter that covers using IPython as a data scientist.
Although much of the material overlaps what we've discussed here, another perspective is always helpful.
Learning IPython for Interactive Computing and Data Visualization: This short book by Cyrille Rossant offers a good
introduction to using IPython for data analysis.
IPython Interactive Computing and Visualization Cookbook: Also by Cyrille Rossant, this book is a longer and more
advanced treatment of using IPython for data science. Despite its name, it's not just about IPython; it also goes into
some depth on a broad range of data science topics.
Python for Data Analysis: 作者:Wes McKinney,其中有一章专门讲述使用IPython来进行数据科学处理。虽然大部分的内容可能与本书我们将要看到的有重复,从另一个角度进行认知永远不是坏事。
Learning IPython for Interactive Computing and Data Visualization: 作者:Cyrille Rossant,一本很简短的书籍,专门介绍使用IPython进行数据分析。
IPython Interactive Computing and Visualization Cookbook: 作者:Cyrille Rossant,一本更加详尽的书籍,对于在数据科学领域使用IPython进行了深入的介绍。虽然名字叫做IPython,实际上内容深度涵盖了数据科学的广泛课题。
Finally, a reminder that you can find help on your own: IPython's ? -based help functionality (discussed in Help and
Documentation in IPython) can be very useful if you use it well and use it often. As you go through the examples here and
elsewhere, this can be used to familiarize yourself with all the tools that IPython has to offer.
最后还是再次提醒⼀下,当你在使⽤IPython时遇到了困难,不要忘记了IPython本⾝⾃带的帮助⼯具 ? (参⻅IPython帮助和⽂档),当你
经常使⽤它,熟练地掌握它之后,你会发现它能带给你的帮助超出你的预期。当你在本书中或其他资源处查看例⼦的时候,它能让你事半
功倍地熟悉IPython中提供的⼯具和功能。
Introduction to NumPy
NumPy介绍
This chapter, along with chapter 3, outlines techniques for effectively loading, storing, and manipulating in-memory data
in Python. The topic is very broad: datasets can come from a wide range of sources and a wide range of formats,
including collections of documents, collections of images, collections of sound clips, collections of numerical
measurements, or nearly anything else. Despite this apparent heterogeneity, it will help us to think of all data
fundamentally as arrays of numbers.
下⾯我们将开启新的⼀章,本章连同第三章⼀起,会介绍和讨论⾼效的装载,存储和处理Python中内存数据的技巧。这个主题⾮常⼴泛:
数据集可能来⾃⾮常不同的来源和⾮常不同的格式,包括⽂档的集合,图像的集合,声⾳⽚段的集合,数值测量的集合,甚⾄其他任何东
西的集合。尽管数据集有着超出想象的异质性,我们还是可以将所有的数据抽象成为数值组成的数组。
For example, images–particularly digital images–can be thought of as simply two-dimensional arrays of numbers
representing pixel brightness across the area. Sound clips can be thought of as one-dimensional arrays of intensity
versus time. Text can be converted in various ways into numerical representations, perhaps binary digits representing the
frequency of certain words or pairs of words. No matter what the data are, the first step in making them analyzable will be to
transform them into arrays of numbers. (We will discuss some specific examples of this process later in Feature
Engineering.)
例如图像,这里我们特指数字图像,可以被认为是简单的二维数组,包含着代表该区域内每个像素亮度的数值。声音片段可以被认为是一维的数组,包含着声音强度随时间变化的数值。文本可以使用各种方法转换成为数值方式表示,比方说使用二进制数字表示某个单词或短语的出现频率。无论数据是哪种类型,使其可被分析的第一步总是将它们转换为数值数组。(参见特征工程)
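A toy illustration of these three encodings, using NumPy (introduced below). The sample rate, tone frequency, image size, and sentence are made-up values chosen only for the example, and simple word counts are just one of many ways of turning text into numbers.

```python
import numpy as np

# A toy "sound clip": one second of a 440 Hz tone sampled 8000 times per second
sample_rate = 8000
t = np.arange(sample_rate) / sample_rate    # sample times in seconds
clip = np.sin(2 * np.pi * 440 * t)          # 1-D array: intensity versus time

# A toy grayscale "image": 4x4 pixel brightness values between 0 and 1
image = np.linspace(0, 1, 16).reshape(4, 4)  # 2-D array: brightness across the area

# A toy text encoding: how often each vocabulary word occurs
words = "the cat sat on the mat".split()
vocab = sorted(set(words))                   # alphabetical list of distinct words
counts = np.array([words.count(w) for w in vocab])

print(clip.shape, image.shape)
print(dict(zip(vocab, counts.tolist())))
```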
For this reason, efficient storage and manipulation of numerical arrays is absolutely fundamental to the process of doing
data science. We'll now take a look at the specialized tools that Python has for handling such numerical arrays: the
NumPy package, and the Pandas package (discussed in Chapter 3).
因此,有效的存储和处理数值数组对于数据科学来说是最根本的能⼒。我们接下来会讨论Python中具备这样强⼤功能的特殊⼯具:NumPy
和Pandas(将在第三章讨论)。
This chapter will cover NumPy in detail. NumPy (short for Numerical Python) provides an efficient interface to store and
operate on dense data buffers. In some ways, NumPy arrays are like Python's built-in list type, but NumPy arrays
provide much more efficient storage and data operations as the arrays grow larger in size. NumPy arrays form the core of
nearly the entire ecosystem of data science tools in Python, so time spent learning to use NumPy effectively will be valuable no
matter what aspect of data science interests you.
本章会详细介绍NumPy(Numerical Python 数值Python的缩写),它提供了强⼤的接⼝供我们存储和操作⾮稀疏数据集合。在某些情况
下,NumPy的数组表现得就像Python內建的 列表 ,但是NumPy数组在存储和操作⼤量数据集合的时候提供了有效得多的功能和性能。
NumPy数组是Python的数据科学领域⼯具链的核⼼,很多其他的⼯具都是在它的基础上构建的,因此⽆论你感兴趣的是数据科学的哪个领
域,NumPy都值得你花时间进⾏钻研。
If you followed the advice outlined in the Preface and installed the Anaconda stack, you already have NumPy installed
and ready to go. If you're more the do-it-yourself type, you can go to http://www.numpy.org/ and follow the installation
instructions found there. Once you do, you can import NumPy and double-check the version:
如果你遵从这本书序⾔的内容安装的Anaconda,那么NumPy已经⾃动安装好了,你可以继续往下阅读。如果你喜欢DIY,你可以到
NumPy官⽹,然后按照提⽰⾃⾏安装。当你完成之后,你就可以在你的脚本中载⼊NumPy模块了,然后输出NumPy的版本号验证安装结
果:
In [1]: import numpy
numpy.__version__
Out[1]: '1.16.4'
For the pieces of the package discussed here, I'd recommend NumPy version 1.8 or later. By convention, you'll find that
most people in the SciPy/PyData world will import NumPy using np as an alias:
对于本书中的例⼦来说,作者推荐安装NumPy 1.8或以上版本。习惯上,⼤多数⼈都会使⽤ np 作为别名来载⼊NumPy模块:
In [2]: import numpy as np
Throughout this chapter, and indeed the rest of the book, you'll find that this is the way we will import and use NumPy.
本章以及本书后续内容,这都是我们载⼊NumPy模块的标准⽅式。
Reminder about Built In Documentation
內建帮助和⽂档
As you read through this chapter, don't forget that IPython gives you the ability to quickly explore the contents of a
package (by using the tab-completion feature), as well as the documentation of various functions (using the ? character;
refer back to Help and Documentation in IPython).
在你阅读本章的过程中,请不要忘记了IPython提供的内建帮助工具 ? 以及使用制表符自动补全的功能(参见:IPython帮助和文档)。
For example, to display all the contents of the numpy namespace, you can type this:
例如,要查看numpy模块中的所有内容(属性和⽅法),你可以输⼊:
In [3]: np.<TAB>
And to display NumPy's built-in documentation, you can use this:
如果想查看numpy的內建⽂档,你可以输⼊:
In [4]: np?
More detailed documentation, along with tutorials and other resources, can be found at http://www.numpy.org.
需要更加详尽的⽂档、教程或其他资源,你可以访问NumPy官⽹。
Understanding Data Types in Python
理解Python中的数据类型
Effective data-driven science and computation requires understanding how data is stored and manipulated. This section
outlines and contrasts how arrays of data are handled in the Python language itself, and how NumPy improves on this.
Understanding this difference is fundamental to understanding much of the material throughout the rest of the book.
想要有效的掌握数据驱动科学和计算需要理解数据是如何存储和处理的。本节将描述和对⽐数组在Python语⾔中和在NumPy中是怎么处理
的,NumPy是如何优化了这部分的内容。理解这个区别是理解本书后续内容的基础。
Users of Python are often drawn in by its ease of use, one piece of which is dynamic typing. While a statically-typed
language like C or Java requires each variable to be explicitly declared, a dynamically-typed language like Python skips
this specification. For example, in C you might specify a particular operation as follows:
Python的用户通常都是被它的易用性吸引来的,其中很重要一环就是动态类型。静态类型的语言,例如C或者Java,每个变量都需要明确声明,而动态类型语言如Python就略过了这个部分。例如,在C中,你可能会写如下的代码片段:
int result = 0;
for(int i=0; i<100; i++){
    result += i;
}
While in Python the equivalent operation could be written this way:
但是在Python当中,等效的代码如下:
result = 0
for i in range(100):
    result += i
Notice the main difference: in C, the data types of each variable are explicitly declared, while in Python the types are
dynamically inferred. This means, for example, that we can assign any kind of data to any variable:
注意其中主要的区别:在C当中,每个变量都需要显式声明,Python的类型是动态推断的。这意味着,我们可以给任何的变量赋值为任何
类型的数据,例如:
x = 4
x = "four"
Here we've switched the contents of x from an integer to a string. The same thing in C would lead (depending on
compiler settings) to a compilation error or other unintended consequences:
上⾯的例⼦中我们将 x 变量的内容从⼀个整数变成了⼀个字符串。如果你想在C语⾔中这样做,取决于不同的编译器,可能会导致⼀个编
译错误或者其他⽆法预料的结果。
int x = 4;
x = "four";  // FAILS: compilation error
This sort of flexibility is one piece that makes Python and other dynamically-typed languages convenient and easy to use.
Understanding how this works is an important piece of learning to analyze data efficiently and effectively with Python. But
what this type-flexibility also points to is the fact that Python variables are more than just their value; they also contain
extra information about the type of the value. We'll explore this more in the sections that follow.
这种灵活性提供了Python和其他动态类型语⾔在使⽤上的简易性。但是,理解这⾥⾯的⼯作原理对于在Python中⾼效准确的学习和分析数
据是⾮常重要的。Python的这种类型灵活性,实际上是付出了额外的存储代价的,变量不仅仅存储了数据本⾝,还需要存储其相应的类
型。我们会在本节接下来的部分继续讨论。
A Python Integer Is More Than Just an Integer
Python的整数不仅仅是一个整数
The standard Python implementation is written in C. This means that every Python object is simply a cleverly-disguised C
structure, which contains not only its value, but other information as well. For example, when we define an integer in
Python, such as x = 10000 , x is not just a "raw" integer. It's actually a pointer to a compound C structure, which
contains several values. Looking through the Python 3.4 source code, we find that the integer (long) type definition
effectively looks like this (once the C macros are expanded):
标准的Python实现是使用C语言编写的。这意味着每个Python当中的对象都是一个伪装良好的C结构体,结构体内不仅仅包括它的值,还有其他的信息。例如,当我们在Python中定义了一个整数,比方说 x=10000 , x 不仅仅是一个原始的整数,它在底层实际上是一个指向复杂C结构体的指针,里面含有若干个字段。当你查阅Python 3.4的源代码的时候,你会发现整数(实际上是长整型)的定义如下(我们将C语言中的宏定义展开后):
struct _longobject {
    long ob_refcnt;
    PyTypeObject *ob_type;
    size_t ob_size;
    long ob_digit[1];
};
A single integer in Python 3.4 actually contains four pieces:
ob_refcnt , a reference count that helps Python silently handle memory allocation and deallocation
ob_type , which encodes the type of the variable
ob_size , which specifies the size of the following data members
ob_digit , which contains the actual integer value that we expect the Python variable to represent.
⼀个Python的整数实际上包含四个部分:
ob_refcnt :引⽤计数器,Python⽤这个字段来进⾏内存分配和垃圾收集
ob_type :变量类型的编码内容
ob_size :表⽰下⾯的数据字段的⻓度
ob_digit :真正的整数值存储在这个字段
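The header overhead is easy to observe in CPython with `sys.getsizeof`, which reports the size of the whole object, header included. A small sketch comparing it with a raw 8-byte C integer (`struct.calcsize("q")`, a native C `long long`). Exact sizes depend on the Python version and platform (on a typical 64-bit CPython a small int takes 28 bytes), and because `ob_digit` is variable-length, larger values take more space.

```python
import struct
import sys

c_int64 = struct.calcsize("q")   # native C long long: the raw 8-byte integer
py_int = sys.getsizeof(10000)    # complete Python int object, header included

print(f"C 64-bit integer: {c_int64} bytes")
print(f"Python int 10000: {py_int} bytes")

# ob_digit grows with the magnitude of the value, so the object grows too
for value in (1, 2**30, 2**100):
    print(f"sys.getsizeof({value}) = {sys.getsizeof(value)} bytes")
```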
This means that there is some overhead in storing an integer in Python as compared to an integer in a compiled
language like C, as illustrated in the following figure:
这意味着在Python中存储⼀个整数要⽐在像C这样的编译语⾔中存储⼀个整数要有损耗,就像下图展⽰的那样:
Here PyObject_HEAD is the part of the structure containing the reference count, type code, and other pieces
mentioned before.
这⾥的 PyObject_HEAD 代表了前⾯的引⽤计数器、类型代码和数据⻓度的三个字段内容。
Notice the difference here: a C integer is essentially a label for a position in memory whose bytes encode an integer
value. A Python integer is a pointer to a position in memory containing all the Python object information, including the
bytes that contain the integer value. This extra information in the Python integer structure is what allows Python to be
coded so freely and dynamically. All this additional information in Python types comes at a cost, however, which becomes
especially apparent in structures that combine many of these objects.
再次注意⼀下这⾥的区别:C的整数就是简单⼀个内存位置,这个位置上的固定⻓度的字节可以表⽰⼀个整数;Python中的⼀个整数是⼀
个指向内存位置的指针,该内存位置包括Python需要表⽰⼀个整数的所有信息,其中最后固定⻓度的字节才真正存储这个整数。这些额外
的信息提供了Python的灵活性和易⽤性。这些Python类型需要的额外信息是有额外损失的,特别是当有⼀个集合需要存储许多这种类型的
数据时。
A Python List Is More Than Just a List
Python的列表不仅仅是一个列表
Let's consider now what happens when we use a Python data structure that holds many Python objects. The standard
mutable multi-element container in Python is the list. We can create a list of integers as follows:
现在我们继续考虑当我们使⽤Python的数据结构来存储许多这样的Python对象时的情况。Python中标准的可变多元素的容器集合就是列
表。我们按如下的⽅式创建⼀个整数的列表:
In [1]: L = list(range(10))
L
Out[1]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
In [2]: type(L[0])
Out[2]: int
Or, similarly, a list of strings:
⼜或者,类似的,字符串的列表:
In [4]: L2 = [str(c) for c in L]  # list comprehension
L2
Out[4]: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
In [5]: type(L2[0])
Out[5]: str
Because of Python's dynamic typing, we can even create heterogeneous lists:
因为Python是动态类型,我们甚⾄可以创建不同类型元素的列表:
In [6]: L3 = [True, "2", 3.0, 4]
[type(item) for item in L3]
Out[6]: [bool, str, float, int]
But this flexibility comes at a cost: to allow these flexible types, each item in the list must contain its own type info,
reference count, and other information; that is, each item is a complete Python object. In the special case that all
variables are of the same type, much of this information is redundant: it can be much more efficient to store data in a
fixed-type array. The difference between a dynamic-type list and a fixed-type (NumPy-style) array is illustrated in the
following figure:
这种灵活性是要付出代价的:要让列表能够容纳不同的类型,每个列表中的元素都必须带有⾃⼰的类型信息、引⽤计数器和其他的信息,
⼀句话,⾥⾯的每个元素都是⼀个完整的Python的对象。如果在所有的元素都是同⼀种类型的情况下,这⾥⾯绝⼤部分的信息都是冗余
的:如果我们能将数据存储在⼀个固定类型的数组中,显然会更加⾼效。下图展⽰了动态类型的列表和固定类型的数组(NumPy实现的)
的区别:
At the implementation level, the array essentially contains a single pointer to one contiguous block of data. The Python
list, on the other hand, contains a pointer to a block of pointers, each of which in turn points to a full Python object like the
Python integer we saw earlier. Again, the advantage of the list is flexibility: because each list element is a full structure
containing both data and type information, the list can be filled with data of any desired type. Fixed-type NumPy-style
arrays lack this flexibility, but are much more efficient for storing and manipulating data.
从底层实现上看,数组仅仅包含⼀个指针指向⼀块连续的内存空间。⽽Python列表,含有⼀个指针指向⼀块连续的指针内存空间,⾥⾯的
每个指针再指向内存中每个独⽴的Python对象,如我们前⾯看到的整数。列表的优势在于灵活:因为每个元素都是完整的Python的类型对
象结构,包含了数据和类型信息,因此列表可以存储任何类型的数据。NumPy使⽤的固定类型的数组缺少这种灵活性,但是对于存储和操
作数据会⾼效许多。
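A rough way to see this difference in CPython is to add up `sys.getsizeof` for the list's pointer block and for every integer object it refers to, and compare the total with a fixed-type `array.array` (from the standard-library `array` module) holding the same values. This is an estimate only: it ignores allocator overhead, and the values start at 1000 so that CPython's shared small-integer objects are not counted more than once.

```python
import array
import sys

values = range(1000, 2000)   # 1000 distinct int objects (outside CPython's small-int cache)

as_list = list(values)
as_array = array.array("q", values)  # one contiguous buffer of 8-byte C integers

# list: header + block of pointers, plus every separate int object those pointers reference
list_bytes = sys.getsizeof(as_list) + sum(sys.getsizeof(v) for v in as_list)
# array: header + contiguous data buffer (getsizeof includes the buffer)
array_bytes = sys.getsizeof(as_array)

print(f"list of ints: ~{list_bytes} bytes")
print(f"array('q'):   ~{array_bytes} bytes")
```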
Fixed-Type Arrays in Python
Python的固定类型数组
Python offers several different options for storing data in efficient, fixed-type data buffers. The built-in array module
(part of the standard library) can be used to create dense arrays of a uniform type:
Python提供了许多不同的选择,能让你使用固定类型的数据缓冲区高效地存储数据。内建的 array 模块(属于标准库)可以用来创建同一类型的数组:
In [7]: import array
L = list(range(10))
A = array.array('i', L)
A
Out[7]: array('i', [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
Here 'i' is a type code indicating the contents are integers.
这⾥的 i 是表⽰数据类型是整数的类型代码。
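The fixed type is enforced: an `array.array` knows its element size and rejects values that do not fit its type code. A small sketch; the 4-byte size noted for `'i'` is typical but platform-dependent.

```python
import array

A = array.array("i", range(10))
print(A.typecode, A.itemsize)   # 'i' = C int; itemsize is usually 4 bytes

A.append(10)                    # fine: another integer
try:
    A.append(3.5)               # a float cannot be stored in an integer array
except TypeError as exc:
    print("TypeError:", exc)
```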
Much more useful, however, is the ndarray object of the NumPy package. While Python's array object provides
efficient storage of array-based data, NumPy also provides efficient operations on that data. We will explore these
operations in later sections; here we'll demonstrate several ways of creating a NumPy array.
更常用的是NumPy包提供的 ndarray 对象。虽然Python的 array 提供了数组的高效存储,NumPy还在此基础上提供了数组的高效运算。我们会在后续小节中陆续介绍这些操作;这里我们首先介绍创建NumPy数组的几种方式。
We'll start with the standard NumPy import, under the alias np :
当然最开始要做的是将NumPy包载⼊,惯例上提供别名 np :
In [8]: import numpy as np
Creating Arrays from Python Lists
使⽤Python列表创建数组
First, we can use np.array to create arrays from Python lists:
⾸先,我们可以使⽤ np.array 来将⼀个Python列表变成⼀个数组:
In [9]: # integer array:
np.array([1, 4, 2, 5, 3])
Out[9]: array([1, 4, 2, 5, 3])
Remember that unlike Python lists, NumPy is constrained to arrays that all contain the same type. If types do not match,
NumPy will upcast if possible (here, integers are up-cast to floating point):
记住和Python列表不同,NumPy数组只能含有同⼀种类型的数据。如果类型不⼀样,NumPy会尝试向上扩展类型(下⾯例⼦中会将整数
向上扩展为浮点数):
In [10]: np.array([3.14, 4, 2, 3])
Out[10]: array([3.14, 4.  , 2.  , 3.  ])
If we want to explicitly set the data type of the resulting array, we can use the dtype keyword:
如果你需要明确指定数据的类型,你可以使⽤ dtype 关键字参数:
In [11]: np.array([1, 2, 3, 4], dtype='float32')
Out[11]: array([1., 2., 3., 4.], dtype=float32)
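The upcasting rule and the `dtype` keyword can be checked by inspecting each array's `dtype` attribute. A brief sketch; the string example goes slightly beyond the text, showing that when a string is mixed in, NumPy converts every element to a string rather than raising an error.

```python
import numpy as np

ints = np.array([1, 4, 2, 5, 3])                    # all integers: integer dtype
mixed = np.array([3.14, 4, 2, 3])                   # integers upcast to floating point
explicit = np.array([1, 2, 3, 4], dtype="float32")  # dtype chosen explicitly
with_str = np.array([1, 2, "3"])                    # every element becomes a string

for arr in (ints, mixed, explicit, with_str):
    print(arr.dtype, arr)
```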
Finally, unlike Python lists, NumPy arrays can explicitly be multi-dimensional; here's one way of initializing a
multidimensional array using a list of lists:
最后,不同于Python的列表,NumPy的数组可以明确表⽰为多维;下⾯例⼦是⼀个使⽤列表的列表来创建⼆维数组的⽅法:
In [13]: # strictly speaking, this is a list of range objects (not generators):
# range(2, 5), range(4, 7) and range(6, 9), each of which becomes one row
np.array([range(i, i + 3) for i in [2, 4, 6]])
Out[13]: array([[2, 3, 4],
[4, 5, 6],
[6, 7, 8]])
The inner lists are treated as rows of the resulting two-dimensional array.
内部的列表作为⼆维数组的⾏。
Creating Arrays from Scratch
从头开始创建数组
Especially for larger arrays, it is more efficient to create arrays from scratch using routines built into NumPy. Here are
several examples:
使⽤NumPy的⽅法从头创建数组会更加⾼效,特别对于⼤型数组来说。下⾯有⼏个例⼦:
In [14]: # zeros: a length-10 integer array filled with 0
np.zeros(10, dtype=int)
Out[14]: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
In [15]: # ones: a 3x5 floating-point array (3 rows, 5 columns) filled with 1
np.ones((3, 5), dtype=float)
Out[15]: array([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
In [16]: # full: a 3x5 array (3 rows, 5 columns) filled with 3.14
np.full((3, 5), 3.14)
Out[16]: array([[3.14, 3.14, 3.14, 3.14, 3.14],
[3.14, 3.14, 3.14, 3.14, 3.14],
[3.14, 3.14, 3.14, 3.14, 3.14]])
In [17]: # arange: like range, create a sequence of values
# starting at 0 (inclusive), ending at 20 (exclusive), with step 2
np.arange(0, 20, 2)
Out[17]: array([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18])
In [18]: # linspace: create values evenly spaced over an interval,
# from 0 (inclusive) to 1 (inclusive), 5 elements in total
np.linspace(0, 1, 5)
Out[18]: array([0.  , 0.25, 0.5 , 0.75, 1.  ])
In [19]: # random.random: a 3x3 array of uniformly distributed
# random values in [0, 1)
np.random.random((3, 3))
Out[19]: array([[0.28957547, 0.80872794, 0.36451325],
[0.30178461, 0.13998063, 0.21693246],
[0.81413802, 0.26299406, 0.53082583]])
In [18]: # random.normal: a 3x3 array of normally distributed values
# with mean 0 and standard deviation 1
np.random.normal(0, 1, (3, 3))
Out[18]: array([[ 1.51772646,  0.39614948, -0.10634696],
       [ 0.25671348,  0.00732722,  0.37783601],
       [ 0.68446945,  0.15926039, -0.70744073]])
In [19]: # random.randint: a 3x3 array of random integers in [0, 10)
np.random.randint(0, 10, (3, 3))
Out[19]: array([[2, 3, 4],
[5, 7, 8],
[0, 5, 0]])
In [20]: # eye: a 3x3 identity matrix
np.eye(3)
Out[20]: array([[ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  1.]])
In [20]: # empty: an uninitialized array of 3 elements; the values are
# whatever already happened to be in that memory
np.empty(3)
Out[20]: array([5.74020278e+180, 4.00193173e-322, 4.66594353e-310])
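The inclusive and exclusive endpoints described in the comments above are easy to confirm directly. The sketch below also uses the Generator-based random API (`np.random.default_rng`, available since NumPy 1.17), which NumPy recommends for new code in place of the legacy `np.random.random`/`normal`/`randint` functions used above; the seed value 0 is arbitrary.

```python
import numpy as np

a = np.arange(0, 20, 2)            # stop value 20 is excluded
l = np.linspace(0, 1, 5)           # both endpoints are included
e = np.eye(3)                      # ones on the diagonal, zeros elsewhere
f = np.full((3, 5), 3.14)          # every element set to the fill value

# Generator-based random API, seeded for reproducibility
rng = np.random.default_rng(0)
u = rng.random((3, 3))             # uniform floats in [0, 1)
r = rng.integers(0, 10, (3, 3))    # integers in [0, 10)

print(a, l, e, f.shape, u.shape, r.shape, sep="\n")
```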
NumPy Standard Data Types
NumPy标准数据类型
NumPy arrays contain values of a single type, so it is important to have detailed knowledge of those types and their
limitations. Because NumPy is built in C, the types will be familiar to users of C, Fortran, and other related languages.
NumPy数组仅包含一种类型的数据,因此详细了解这些类型及其限制非常重要。因为NumPy是使用C构建的,它的类型系统对于C、Fortran及其他相关语言的用户来说不会陌生。
The standard NumPy data types are listed in the following table. Note that when constructing an array, they can be
specified using a string:
标准NumPy数据类型⻅下表。正如上⾯介绍的,当我们创建数组的时候,我们可以将 dtype 参数指定为下⾯类型的字符串名称来指定数
组的数据类型。
np.zeros(10, dtype='int16')
Or using the associated NumPy object:
也可以将 dtype 指定为对应的NumPy对象:
np.zeros(10, dtype=np.int16)
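Data type     Description
bool_         Boolean (True or False) stored as one byte
int_          Default integer type (like C long; normally either int64 or int32)
intc          Like C int (normally int32 or int64)
intp          Integer used for indexing (like C ssize_t; normally either int32 or int64)
int8          Integer, 1 byte (-128 to 127)
int16         Integer, 2 bytes (-32768 to 32767)
int32         Integer, 4 bytes (-2147483648 to 2147483647)
int64         Integer, 8 bytes (-9223372036854775808 to 9223372036854775807)
uint8         Unsigned integer, 1 byte (0 to 255)
uint16        Unsigned integer, 2 bytes (0 to 65535)
uint32        Unsigned integer, 4 bytes (0 to 4294967295)
uint64        Unsigned integer, 8 bytes (0 to 18446744073709551615)
float_        Shorthand for float64 (removed in NumPy 2.0; use float64)
float16       Half-precision float: 1 sign bit, 5 exponent bits, 10 mantissa bits
float32       Single-precision float: 1 sign bit, 8 exponent bits, 23 mantissa bits
float64       Double-precision float: 1 sign bit, 11 exponent bits, 52 mantissa bits
complex_      Shorthand for complex128 (removed in NumPy 2.0; use complex128)
complex64     Complex number, represented by two single-precision floats
complex128    Complex number, represented by two double-precision floats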
More advanced type specification is possible, such as specifying big or little endian numbers; for more information, refer
to the NumPy documentation. NumPy also supports compound data types, which will be covered in Structured Data:
NumPy's Structured Arrays.
还有更多的高级的类型声明,比如指定大端或小端字节序;需要获得更多内容,请查阅NumPy在线文档。NumPy也支持复合数据类型,这部分我们将在结构化数据:NumPy里的结构化数组中进行介绍。
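The ranges in the table don't need to be memorized: `np.iinfo` and `np.finfo` report them for any integer or floating-point type, and a leading `>` or `<` in a type string selects big-endian or little-endian byte order. A short sketch:

```python
import numpy as np

info16 = np.iinfo(np.int16)
print(info16.min, info16.max)           # range of a 2-byte integer
print(np.iinfo(np.uint8).max)           # largest 1-byte unsigned integer
print(np.finfo(np.float32).eps)         # gap between 1.0 and the next float32

big = np.array([1], dtype=">i4")        # 4-byte integer, big-endian
little = np.array([1], dtype="<i4")     # 4-byte integer, little-endian
print(big.tobytes(), little.tobytes())  # same value, opposite byte order
```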
The Basics of NumPy Arrays
NumPy数组基础
Data manipulation in Python is nearly synonymous with NumPy array manipulation: even newer tools like Pandas
(Chapter 3) are built around the NumPy array. This section will present several examples of using NumPy array
manipulation to access data and subarrays, and to split, reshape, and join the arrays. While the types of operations
shown here may seem a bit dry and pedantic, they comprise the building blocks of many other examples used throughout
the book. Get to know them well!
Python中的数据操作基本就是NumPy数组操作的同义词:一些新的工具像Pandas(第三章)都是依赖于NumPy数组建立起来的。本节会展示使用NumPy数组操作来访问数据和子数组的一些例子,以及数组的切分、变形和组合。尽管这里展示的操作有些枯燥和学术化,但是它们是组成本书后面使用的例子的基础。你应该好好掌握它们。
We'll cover a few categories of basic array manipulations here:
Attributes of arrays: Determining the size, shape, memory consumption, and data types of arrays
Indexing of arrays: Getting and setting the value of individual array elements
Slicing of arrays: Getting and setting smaller subarrays within a larger array
Reshaping of arrays: Changing the shape of a given array
Joining and splitting of arrays: Combining multiple arrays into one, and splitting one array into many
NumPy Array Attributes
First let's discuss some useful array attributes. We'll start by defining three random arrays: a one-dimensional, a two-dimensional, and a three-dimensional array. We'll use NumPy's random number generator, seeded with a fixed value so that the same random arrays are generated each time this code is run:
In [1]: import numpy as np
np.random.seed(0)  # seed for reproducibility
x1 = np.random.randint(10, size=6)  # one-dimensional array
x2 = np.random.randint(10, size=(3, 4))  # two-dimensional array
x3 = np.random.randint(10, size=(3, 4, 5))  # three-dimensional array
Each array has the attributes ndim (the number of dimensions), shape (the size of each dimension), and size (the total size of the array):
In [2]: # print the number of dimensions, shape, and size of x3
print("x3 ndim: ", x3.ndim)
print("x3 shape:", x3.shape)
print("x3 size: ", x3.size)
x3 ndim: 3
x3 shape: (3, 4, 5)
x3 size: 60
Another useful attribute is dtype, the data type of the array (which we discussed previously in Understanding Data Types in Python):
In [3]: print("dtype:", x3.dtype)
dtype: int64
Other attributes include itemsize, which lists the size (in bytes) of each array element, and nbytes, which lists the total size (in bytes) of the array:
In [4]: print("itemsize:", x3.itemsize, "bytes")
print("nbytes:", x3.nbytes, "bytes")
itemsize: 8 bytes
nbytes: 480 bytes
In general, we expect that nbytes is equal to itemsize times size.
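Here is a minimal sketch that checks this relationship. The array is a stand-in with the same shape as x3. Its int64 dtype is set explicitly because the default integer type can vary by platform.

```python
import numpy as np

# Same shape as x3 in the text; int64 is chosen explicitly because the
# default integer dtype depends on the platform and NumPy version.
x3 = np.zeros((3, 4, 5), dtype=np.int64)
print(x3.itemsize, x3.size, x3.nbytes)      # 8 60 480
assert x3.nbytes == x3.itemsize * x3.size

# A narrower dtype shrinks the memory footprint in proportion to itemsize.
x3_small = x3.astype(np.int8)
print(x3_small.itemsize, x3_small.nbytes)   # 1 60
```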
Array Indexing: Accessing Single Elements
If you are familiar with Python's standard list indexing, indexing in NumPy will feel quite familiar. In a one-dimensional array, the i-th value (counting from zero) can be accessed by specifying the desired index in square brackets, just as with Python lists:
In [5]: x1
Out[5]: array([5, 0, 3, 3, 7, 9])
In [6]: x1[0]
Out[6]: 5
In [7]: x1[4]
Out[7]: 7
To index from the end of the array, you can use negative indices:
In [8]: x1[-1]
Out[8]: 9
In [9]: x1[-2]
Out[9]: 7
In a multi-dimensional array, items can be accessed using a comma-separated tuple of indices:
Translator's note: indexing a multi-dimensional array this way differs from indexing a list of lists. A Python list of lists has to be indexed with multiple pairs of brackets, as in x[i][j].
In [10]: x2
Out[10]: array([[3, 5, 2, 4],
[7, 6, 8, 8],
[1, 6, 7, 7]])
In [11]: x2[0, 0]
Out[11]: 3
In [12]: x2[2, 0]
Out[12]: 1
In [13]: x2[2, -1]
Out[13]: 7
Values can also be modified using any of the above index notation:
In [14]: x2[0, 0] = 12
x2
Out[14]: array([[12,  5,  2,  4],
       [ 7,  6,  8,  8],
       [ 1,  6,  7,  7]])
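The following small sketch puts tuple indexing on an array next to chained indexing on a nested list. The names nested and arr are illustrative, and they hold the same values as the original x2.

```python
import numpy as np

nested = [[3, 5, 2, 4],
          [7, 6, 8, 8],
          [1, 6, 7, 7]]      # a plain Python list of lists
arr = np.array(nested)        # the same values as a 2-D NumPy array

# A list needs chained brackets: the first picks a row list, the second an item.
print(nested[2][0])
# A NumPy array takes a single tuple of indices...
print(arr[2, 0])
# ...though chained brackets also work, because arr[2] is itself a 1-D array.
print(arr[2][0])

raised = False
try:
    nested[2, 0]              # a tuple is not a valid list index
except TypeError:
    raised = True
print("tuple index on a list raises TypeError:", raised)
```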
Keep in mind that, unlike Python lists, NumPy arrays have a fixed type. This means, for example, that if you attempt to insert a floating-point value into an integer array, the value will be silently truncated. Don't be caught unaware by this behavior!
In [15]: x1[0] = 3.14159  # this will be truncated to an integer!
x1
Out[15]: array([3, 0, 3, 3, 7, 9])
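If the fractional values need to be kept, one option is to convert the array to a floating-point dtype before assigning. The astype method, for example, returns a converted copy. Here is a minimal sketch with illustrative names:

```python
import numpy as np

ints = np.array([5, 0, 3, 3, 7, 9])
ints[0] = 3.14159              # fractional part is dropped to fit the integer dtype
print(ints[0], ints.dtype)

floats = ints.astype(float)    # a float64 copy of the data
floats[0] = 3.14159            # now the fractional part is kept
print(floats[0], floats.dtype)
```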
Array Slicing: Accessing Subarrays
Just as we can use square brackets to access individual array elements, we can also use them to access subarrays with the slice notation, marked by the colon (:) character. The NumPy slicing syntax follows that of the standard Python list; to access a slice of an array x, use this:
x[start:stop:step]
If any of these are unspecified, they default to the values start=0, stop=size of dimension, step=1. We'll take a look at accessing subarrays in one dimension and in multiple dimensions.
One-dimensional subarrays
In [16]: x = np.arange(10)
x
Out[16]: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
In [17]: x[:5]  # first five elements
Out[17]: array([0, 1, 2, 3, 4])
In [19]: x[5:]  # elements from index 5 onward
Out[19]: array([5, 6, 7, 8, 9])
In [20]: x[4:7]  # middle subarray (indices 4 to 6)
Out[20]: array([4, 5, 6])
In [21]: x[::2]  # every other element
Out[21]: array([0, 2, 4, 6, 8])
In [22]: x[1::2]  # every other element, starting at index 1
Out[22]: array([1, 3, 5, 7, 9])
A potentially confusing case is when the step value is negative. In this case, the defaults for start and stop are swapped. This becomes a convenient way to reverse an array:
Translator's note: beginners coming to Python from other languages often ask how to reverse a string, since str has no method for it. The answer is that none is needed, because slicing already does it:
s = 'hello world'
print(s[::-1])  # prints dlrow olleh
In [23]: x[::-1]  # all elements, reversed
Out[23]: array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
In [24]: x[5::-2]  # every other element, going backward from index 5
Out[24]: array([5, 3, 1])
Multi-dimensional subarrays
Multi-dimensional slices work in the same way, with multiple slices separated by commas. For example:
In [25]: x2
Out[25]: array([[12,  5,  2,  4],
       [ 7,  6,  8,  8],
       [ 1,  6,  7,  7]])
In [26]: x2[:2, :3]  # first two rows, first three columns; shape (2, 3)
Out[26]: array([[12,  5,  2],
       [ 7,  6,  8]])
In [27]: x2[:3, ::2]  # all three rows, every other column; shape (3, 2)
Out[27]: array([[12,  2],
       [ 7,  8],
       [ 1,  7]])
Finally, subarray dimensions can even be reversed together:
In [28]: x2[::-1, ::-1]  # reverse both rows and columns; shape stays (3, 4)
Out[28]: array([[ 7,  7,  6,  1],
       [ 8,  8,  6,  7],
       [ 4,  2,  5, 12]])
Accessing array rows and columns
One commonly needed routine is accessing single rows or columns of an array. This can be done by combining indexing and slicing, using an empty slice marked by a single colon (:):
In [29]: print(x2[:, 0])  # first column of x2
[12  7  1]
In [30]: print(x2[0, :])  # first row of x2
[12  5  2  4]
In the case of row access, the empty slice can be omitted for a more compact syntax:
In [31]: print(x2[0])  # equivalent to x2[0, :]
[12  5  2  4]
Subarrays as no-copy views
One important (and extremely useful) thing to know about array slices is that they return views rather than copies of the array data. This is one area in which NumPy array slicing differs from Python list slicing: in lists, slices will be copies. Consider our two-dimensional array from before:
In [32]: print(x2)
[[12  5  2  4]
 [ 7  6  8  8]
 [ 1  6  7  7]]
Let's extract a 2 × 2 subarray from this:
In [33]: x2_sub = x2[:2, :2]
print(x2_sub)
[[12  5]
 [ 7  6]]
Now if we modify this subarray, we'll see that the original array is changed! Observe:
In [34]: x2_sub[0, 0] = 99
print(x2_sub)
[[99  5]
 [ 7  6]]
In [35]: print(x2)
[[99  5  2  4]
 [ 7  6  8  8]
 [ 1  6  7  7]]
This default behavior is actually quite useful: it means that when we work with large datasets, we can access and process pieces of these datasets without the need to copy the underlying data buffer.
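The sketch below contrasts list slicing with array slicing and shows two ways to confirm that a slice is a view. np.shares_memory and the base attribute are standard NumPy; the variable names are illustrative.

```python
import numpy as np

lst = [0, 1, 2, 3, 4]
lst_slice = lst[:2]             # list slicing copies the elements
lst_slice[0] = 99
print(lst)                      # the original list is unchanged

arr = np.arange(5)
arr_slice = arr[:2]             # array slicing returns a view of the same buffer
arr_slice[0] = 99
print(arr)                      # the original array reflects the change

print(arr_slice.base is arr)              # the view keeps a reference to its parent
print(np.shares_memory(arr, arr_slice))   # both objects use the same memory
```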
Creating copies of arrays
Despite the nice features of array views, it is sometimes useful to explicitly copy the data within an array or a subarray instead. This can most easily be done with the copy() method:
In [36]: x2_sub_copy = x2[:2, :2].copy()
print(x2_sub_copy)
[[99  5]
 [ 7  6]]
If we now modify this subarray, the original array is not touched:
In [37]: x2_sub_copy[0, 0] = 42
print(x2_sub_copy)
[[42  5]
 [ 7  6]]
In [38]: print(x2)
[[99  5  2  4]
 [ 7  6  8  8]
 [ 1  6  7  7]]
Reshaping of Arrays
Another useful type of operation is reshaping of arrays. The most flexible way of doing this is with the reshape method. For example, if you want to put the numbers 1 through 9 in a 3 × 3 grid, you can do the following:
In [39]: grid = np.arange(1, 10).reshape((3, 3))
print(grid)
[[1 2 3]
[4 5 6]
[7 8 9]]
Note that for this to work, the size of the initial array must match the size of the reshaped array. Where possible, the reshape method will return a no-copy view of the initial array, but with non-contiguous memory buffers this is not always the case.
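One way to check whether reshape produced a view or a copy is np.shares_memory. The sketch below uses illustrative arrays to compare a contiguous array with its transpose, whose memory layout is not C-contiguous.

```python
import numpy as np

a = np.arange(6).reshape((2, 3))   # C-contiguous 2 x 3 array
b = a.reshape((3, 2))              # contiguous data, so reshape can return a view
print(np.shares_memory(a, b))

t = a.T                            # transpose: same buffer, but not C-contiguous
flat = t.reshape(6)                # flattening it in C order requires a copy
print(t.flags['C_CONTIGUOUS'], np.shares_memory(t, flat))
print(flat)
```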
Another common reshaping pattern is the conversion of a one-dimensional array into a two-dimensional row or column matrix. This can be done with the reshape method, or more easily by using np.newaxis (an alias for None) within a slice operation:
In [40]: x = np.array([1, 2, 3])
# row vector via reshape: shape (1, 3)
x.reshape((1, 3))
Out[40]: array([[1, 2, 3]])
In [41]: # row vector via newaxis: also shape (1, 3)
x[np.newaxis, :]
Out[41]: array([[1, 2, 3]])
In [42]: # column vector via reshape: shape (3, 1)
x.reshape((3, 1))
Out[42]: array([[1],
       [2],
       [3]])
In [43]: # column vector via newaxis: also shape (3, 1)
x[:, np.newaxis]
Out[43]: array([[1],
       [2],
       [3]])
We will see this type of transformation often throughout the remainder of the book.
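This sketch shows two related conveniences. Because np.newaxis is simply an alias for None, either spelling works in a slice. Passing -1 to reshape lets NumPy infer that dimension from the array's size.

```python
import numpy as np

x = np.array([1, 2, 3])
print(np.newaxis is None)          # np.newaxis is an alias for None
row = x[np.newaxis, :]
col = x[:, None]                   # None works in its place
print(row.shape, col.shape)

# -1 asks reshape to infer that dimension's length from the total size
print(x.reshape(1, -1).shape, x.reshape(-1, 1).shape)
```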
Array Concatenation and Splitting
All of the preceding routines worked on single arrays. It's also possible to combine multiple arrays into one and, conversely, to split a single array into multiple arrays. We'll take a look at those operations here.
Concatenation of arrays
Concatenation, or joining of two arrays in NumPy, is primarily accomplished using the routines np.concatenate, np.vstack, and np.hstack. np.concatenate takes a tuple or list of arrays as its first argument, as we can see here:
In [44]: x = np.array([1, 2, 3])
y = np.array([3, 2, 1])
np.concatenate([x, y])
Out[44]: array([1, 2, 3, 3, 2, 1])
You can also concatenate more than two arrays at once:
In [45]: z = [99, 99, 99]
print(np.concatenate([x, y, z]))
[ 1  2  3  3  2  1 99 99 99]
It can also be used for two-dimensional arrays:
In [46]: grid = np.array([[1, 2, 3],
                 [4, 5, 6]])
In [47]: # concatenate along the first axis (rows); axis=0 is the default
np.concatenate([grid, grid])
Out[47]: array([[1, 2, 3],
       [4, 5, 6],
       [1, 2, 3],
       [4, 5, 6]])
In [48]: # concatenate along the second axis (columns), axis=1
np.concatenate([grid, grid], axis=1)
Out[48]: array([[1, 2, 3, 1, 2, 3],
       [4, 5, 6, 4, 5, 6]])
For working with arrays of mixed dimensions, it can be clearer to use the np.vstack (vertical stack) and np.hstack (horizontal stack) functions:
In [49]: x = np.array([1, 2, 3])
grid = np.array([[9, 8, 7],
                 [6, 5, 4]])
# vertically stack the arrays
np.vstack([x, grid])
Out[49]: array([[1, 2, 3],
       [9, 8, 7],
       [6, 5, 4]])
In [50]: # horizontally stack the arrays
y = np.array([[99],
              [99]])
np.hstack([grid, y])
Out[50]: array([[ 9,  8,  7, 99],
       [ 6,  5,  4, 99]])
Similarly, np.dstack will stack arrays along the third axis.
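For 2-D inputs, np.dstack first gives each array a third axis of length 1 and then joins them along that axis. Here is a small sketch with illustrative arrays:

```python
import numpy as np

a = np.array([[1, 2],
              [3, 4]])
b = np.array([[5, 6],
              [7, 8]])
stacked = np.dstack([a, b])   # each (2, 2) input becomes (2, 2, 1), then they are joined
print(stacked.shape)
print(stacked[:, :, 0])       # the first layer along the depth axis is a
print(stacked[:, :, 1])       # the second layer is b
```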
Splitting of arrays
The opposite of concatenation is splitting, which is implemented by the functions np.split, np.hsplit, and np.vsplit. For each of these, we can pass a list of indices giving the split points:
In [51]: x = [1, 2, 3, 99, 99, 3, 2, 1]
x1, x2, x3 = np.split(x, [3, 5])  # split at indices 3 and 5, giving three arrays
print(x1, x2, x3)
[1 2 3] [99 99] [3 2 1]
Notice that N split points lead to N + 1 subarrays. The related functions np.hsplit and np.vsplit are similar:
In [52]: grid = np.arange(16).reshape((4, 4))
grid
Out[52]: array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
In [54]: upper, lower = np.vsplit(grid, [2])  # split vertically at row index 2
print(upper)
print(lower)
[[0 1 2 3]
 [4 5 6 7]]
[[ 8  9 10 11]
 [12 13 14 15]]
In [55]: left, right = np.hsplit(grid, [2])  # split horizontally at column index 2
print(left)
print(right)
[[ 0  1]
 [ 4  5]
 [ 8  9]
 [12 13]]
[[ 2  3]
 [ 6  7]
 [10 11]
 [14 15]]
Similarly, np.dsplit will split arrays along the third axis.
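np.split also accepts an integer instead of a list of split points, meaning that many equal-sized sections. If the length does not divide evenly, it raises a ValueError, whereas np.array_split allows unequal pieces. The sketch below uses illustrative arrays and also shows np.dsplit on a 3-D array.

```python
import numpy as np

parts = np.split(np.arange(9), 3)            # three equal sections
print(parts)

raised = False
try:
    np.split(np.arange(10), 3)               # 10 is not divisible by 3
except ValueError:
    raised = True
uneven = np.array_split(np.arange(10), 3)    # unequal sections are allowed here
print(raised, [p.tolist() for p in uneven])

# dsplit works along the third axis of an array with at least three dimensions
cube = np.arange(16).reshape((2, 2, 4))
front, back = np.dsplit(cube, [2])
print(front.shape, back.shape)
```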
Computation on NumPy Arrays: Universal Functions
Up until now, we have been discussing some of the basic nuts and bolts of NumPy; in the next few sections, we will dive into the reasons that NumPy is so important in the Python data science world. Namely, it provides an easy and flexible interface to optimized computation with arrays of data.
Computation on NumPy arrays can be very fast, or it can be very slow. The key to making it fast is to use vectorized operations, generally implemented through NumPy's universal functions (ufuncs). This section motivates the need for NumPy's ufuncs, which can be used to make repeated calculations on array elements much more efficient. It then introduces many of the most common and useful arithmetic ufuncs available in the NumPy package.
The Slowness of Loops
Python's default implementation (known as CPython) does some operations very slowly. This is in part due to the dynamic, interpreted nature of the language: because types are flexible, sequences of operations cannot be compiled down to efficient machine code as in languages like C and Fortran. Recently there have been various attempts to address this weakness. Well-known examples are the PyPy project, a just-in-time compiled implementation of Python; the Cython project, which converts Python code to compilable C code; and the Numba project, which compiles snippets of Python code to fast machine code via LLVM. Each of these has its strengths and weaknesses, but it is safe to say that none of the three approaches has yet surpassed the reach and popularity of the standard CPython engine.
The relative sluggishness of Python generally manifests itself in situations where many small operations are being repeated, for instance looping over arrays to operate on each element. For example, imagine we have an array of values and we'd like to compute the reciprocal of each. A straightforward approach might look like this:
In [1]: import numpy as np
np.random.seed(0)
def compute_reciprocals(values):
    output = np.empty(len(values))
    for i in range(len(values)):
        output[i] = 1.0 / values[i]
    return output
values = np.random.randint(1, 10, size=5)
compute_reciprocals(values)
Out[1]: array([0.16666667, 1.        , 0.25      , 0.25      , 0.125     ])
This implementation probably feels fairly natural to someone from, say, a C or Java background. But if we measure the
execution time of this code for a large input, we see that this operation is very slow, perhaps surprisingly so! We'll
benchmark this with IPython's %timeit magic (discussed in Profiling and Timing Code):
In [2]: big_array = np.random.randint(1, 100, size=1000000)
%timeit compute_reciprocals(big_array)
4.07 s ± 68.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
It takes several seconds to compute these million operations and to store the result! When even cell phones have
processing speeds measured in Giga-FLOPS (i.e., billions of numerical operations per second), this seems almost
absurdly slow. It turns out that the bottleneck here is not the operations themselves, but the type-checking and function
dispatches that CPython must do at each cycle of the loop. Each time the reciprocal is computed, Python first examines
the object's type and does a dynamic lookup of the correct function to use for that type. If we were working in compiled
code instead, this type specification would be known before the code executes and the result could be computed much
more efficiently.
Introducing UFuncs
For many types of operations, NumPy provides a convenient interface into just this kind of statically typed, compiled routine. This is known as a vectorized operation. It is accomplished by simply performing an operation on the array, which will then be applied to each element. This vectorized approach is designed to push the loop into the compiled layer that underlies NumPy, leading to much faster execution.
Compare the results of the following two:
In [3]: print(compute_reciprocals(values))
print(1.0 / values)
[0.16666667 1.         0.25       0.25       0.125     ]
[0.16666667 1.         0.25       0.25       0.125     ]
Looking at the execution time for our big array, we see that it completes orders of magnitude faster than the Python loop:
In [4]: %timeit (1.0 / big_array)
1.53 ms ± 24.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Vectorized operations in NumPy are implemented via ufuncs, whose main purpose is to quickly execute repeated operations on values in NumPy arrays. Ufuncs are extremely flexible: before, we saw an operation between a scalar and an array, but we can also operate between two arrays:
In [5]: np.arange(5) / np.arange(1, 6)
Out[5]: array([0.        , 0.5       , 0.66666667, 0.75      , 0.8       ])
And ufunc operations are not limited to one-dimensional arrays; they can act on multi-dimensional arrays as well:
In [6]: x = np.arange(9).reshape((3, 3))
2 ** x
Out[6]: array([[  1,   2,   4],
       [  8,  16,  32],
       [ 64, 128, 256]])
Computations using vectorization through ufuncs are nearly always more efficient than their counterparts implemented using Python loops, especially as the arrays grow in size. Any time you see such a loop in a Python script, you should consider whether it can be replaced with a vectorized expression.
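Here is a sketch of that kind of rewrite. The two hypothetical helpers below compute the same polynomial, 3*x**2 - 2*x + 1. The first uses an element-by-element loop in the style of compute_reciprocals. The second uses a single vectorized expression in which every operator is a ufunc.

```python
import numpy as np

def poly_loop(values):
    # element-by-element Python loop, in the style of compute_reciprocals
    output = np.empty(len(values))
    for i in range(len(values)):
        output[i] = 3 * values[i] ** 2 - 2 * values[i] + 1
    return output

def poly_vectorized(values):
    # the same polynomial written once for the whole array
    return 3 * values ** 2 - 2 * values + 1

values = np.arange(1, 1001)
print(np.allclose(poly_loop(values), poly_vectorized(values)))
```

Timing the two with %timeit on a large input should show the same kind of gap as the reciprocal example, though the exact numbers depend on the machine.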
Exploring NumPy's UFuncs
Ufuncs exist in two flavors: unary ufuncs, which operate on a single input, and binary ufuncs, which operate on two inputs. We'll see examples of both these types of functions here.
Array arithmetic
NumPy's ufuncs feel very natural to use because they make use of Python's native arithmetic operators. The standard addition, subtraction, multiplication, and division can all be used:
In [7]: x = np.arange(4)
print("x      =", x)
print("x + 5  =", x + 5)
print("x - 5  =", x - 5)
print("x * 2  =", x * 2)
print("x / 2  =", x / 2)
print("x // 2 =", x // 2)  # floor division
x      = [0 1 2 3]
x + 5  = [5 6 7 8]
x - 5  = [-5 -4 -3 -2]
x * 2  = [0 2 4 6]
x / 2  = [0.  0.5 1.  1.5]
x // 2 = [0 0 1 1]
There is also a unary ufunc for negation, a ** operator for exponentiation, and a % operator for modulus:
In [8]: print("-x     = ", -x)
print("x ** 2 = ", x ** 2)
print("x % 2  = ", x % 2)
-x     =  [ 0 -1 -2 -3]
x ** 2 =  [0 1 4 9]
x % 2  =  [0 1 0 1]
In addition, these can be strung together however you wish, and the standard order of operations is respected:
In [9]: -(0.5*x + 1) ** 2
Out[9]: array([-1.  , -2.25, -4.  , -6.25])
Each of these arithmetic operations is simply a convenient wrapper around a specific function built into NumPy; for example, the + operator is a wrapper for the add function:
In [10]: np.add(x, 2)
Out[10]: array([2, 3, 4, 5])
The following table lists the arithmetic operators implemented in NumPy:
Operator | Equivalent ufunc | Description
+ | np.add | Addition (e.g., 1 + 1 = 2)
- | np.subtract | Subtraction (e.g., 3 - 2 = 1)
- | np.negative | Unary negation (e.g., -2)
* | np.multiply | Multiplication (e.g., 2 * 3 = 6)
/ | np.divide | Division (e.g., 3 / 2 = 1.5)
// | np.floor_divide | Floor division (e.g., 3 // 2 = 1)
** | np.power | Exponentiation (e.g., 2 ** 3 = 8)
% | np.mod | Modulus/remainder (e.g., 9 % 4 = 1)
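A quick way to confirm that each operator in the table matches its ufunc is to compare the results directly. Here is a minimal sketch on an illustrative integer array:

```python
import numpy as np

x = np.arange(1, 5)
pairs = [
    (x + 2, np.add(x, 2)),
    (x - 2, np.subtract(x, 2)),
    (-x, np.negative(x)),
    (x * 2, np.multiply(x, 2)),
    (x / 2, np.divide(x, 2)),
    (x // 2, np.floor_divide(x, 2)),
    (x ** 2, np.power(x, 2)),
    (x % 2, np.mod(x, 2)),
]
# each pair holds (operator result, explicit ufunc result)
all_match = all(np.array_equal(op, uf) for op, uf in pairs)
print(all_match)
```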
Additionally there are Boolean/bitwise operators; we will explore these in Comparisons, Masks, and Boolean Logic.
Absolute value
Just as NumPy understands Python's built-in arithmetic operators, it also understands Python's built-in absolute value function:
In [11]: x = np.array([-2, -1, 0, 1, 2])
abs(x)
Out[11]: array([2, 1, 0, 1, 2])
The corresponding NumPy ufunc is np.absolute, which is also available under the alias np.abs:
In [12]: np.absolute(x)
Out[12]: array([2, 1, 0, 1, 2])
In [13]: np.abs(x)
Out[13]: array([2, 1, 0, 1, 2])
This ufunc can also handle complex data, in which case the absolute value returns the magnitude:
In [14]: x = np.array([3 - 4j, 4 - 3j, 2 + 0j, 0 + 1j])
np.abs(x)
Out[14]: array([5., 5., 2., 1.])
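For a complex number a + bj the magnitude is sqrt(a**2 + b**2). The result can therefore be cross-checked with np.hypot applied to the real and imaginary parts. Here is a small sketch:

```python
import numpy as np

z = np.array([3 - 4j, 4 - 3j, 2 + 0j, 0 + 1j])
magnitude = np.abs(z)
# hypot computes sqrt(a**2 + b**2) elementwise
print(np.allclose(magnitude, np.hypot(z.real, z.imag)))
print(magnitude.dtype)   # the magnitudes are real-valued
```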
Trigonometric functions
NumPy provides a large number of useful ufuncs, and some of the most useful for the data scientist are the trigonometric functions. We'll start by defining an array of angles:
In [15]: theta = np.linspace(0, np.pi, 3)
Now we can compute some trigonometric functions on these values:
In [16]: print("theta      = ", theta)
print("sin(theta) = ", np.sin(theta))  # sine
print("cos(theta) = ", np.cos(theta))  # cosine
print("tan(theta) = ", np.tan(theta))  # tangent
theta      =  [0.         1.57079633 3.14159265]
sin(theta) =  [0.0000000e+00 1.0000000e+00 1.2246468e-16]
cos(theta) =  [ 1.000000e+00  6.123234e-17 -1.000000e+00]
tan(theta) =  [ 0.00000000e+00  1.63312394e+16 -1.22464680e-16]
The values are computed to within machine precision, which is why values that should be zero do not always hit exactly zero. Inverse trigonometric functions are also available:
In [17]: x = [-1, 0, 1]
print("x         = ", x)
print("arcsin(x) = ", np.arcsin(x))  # inverse sine
print("arccos(x) = ", np.arccos(x))  # inverse cosine
print("arctan(x) = ", np.arctan(x))  # inverse tangent
x         =  [-1, 0, 1]
arcsin(x) =  [-1.57079633  0.          1.57079633]
arccos(x) =  [3.14159265 1.57079633 0.        ]
arctan(x) =  [-0.78539816  0.          0.78539816]
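Because of these rounding effects, floating-point results are usually compared with a tolerance rather than with ==. The sketch below uses np.isclose and np.allclose. When comparing against zero, only the absolute tolerance atol (1e-8 by default) matters.

```python
import numpy as np

theta = np.linspace(0, np.pi, 3)
s = np.sin(theta)
print(s[-1] == 0.0)                      # exact comparison fails: sin(pi) is tiny but nonzero
print(np.isclose(s[-1], 0.0))            # comparison within a tolerance succeeds
print(np.allclose(s, [0.0, 1.0, 0.0]))   # the whole array matches the ideal values
```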
Exponents and logarithms
Another common type of operation available in a NumPy ufunc is the exponential:
In [18]: x = [1, 2, 3]
print("x   =", x)
print("e^x =", np.exp(x))
print("2^x =", np.exp2(x))
print("3^x =", np.power(3, x))
x   = [1, 2, 3]
e^x = [ 2.71828183  7.3890561  20.08553692]
2^x = [2. 4. 8.]
3^x = [ 3  9 27]
The inverse of the exponentials, the logarithms, are also available. The basic np.log gives the natural logarithm; if you prefer to compute the base-2 logarithm or the base-10 logarithm, these are available as well:
In [19]: x = [1, 2, 4, 10]
print("x        =", x)
print("ln(x)    =", np.log(x))
print("log2(x)  =", np.log2(x))
print("log10(x) =", np.log10(x))
x        = [1, 2, 4, 10]
ln(x)    = [0.         0.69314718 1.38629436 2.30258509]
log2(x)  = [0.         1.         2.         3.32192809]
log10(x) = [0.         0.30103    0.60205999 1.        ]
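For a base other than e, 2, or 10, one option is the change-of-base formula log_b(x) = ln(x) / ln(b). The log_base function below is a hypothetical helper written for illustration:

```python
import numpy as np

def log_base(x, base):
    # hypothetical helper: change of base, log_b(x) = ln(x) / ln(b)
    return np.log(x) / np.log(base)

x = np.array([1, 2, 4, 10])
print(np.allclose(log_base(x, 2), np.log2(x)))
print(np.allclose(log_base(x, 10), np.log10(x)))
print(log_base(np.array([1, 5, 25]), 5))
```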
There are also some specialized versions that are useful for maintaining precision with very small input:
In [20]: x = [0, 0.001, 0.01, 0.1]
print("exp(x) - 1 =", np.expm1(x))
print("log(1 + x) =", np.log1p(x))
exp(x) - 1 = [0.         0.0010005  0.01005017 0.10517092]
log(1 + x) = [0.         0.0009995  0.00995033 0.09531018]
When x is very small, these functions give more precise values than if the raw np.log or np.exp were to be used.
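The loss of precision becomes visible with a tiny input. The sketch below compares np.log(1 + x) and np.log1p(x) against the series approximation x - x**2/2, which is accurate to double precision for x = 1e-10.

```python
import numpy as np

x = 1e-10
naive = np.log(1 + x)    # 1 + x is rounded to the nearest double first, losing digits of x
precise = np.log1p(x)    # computes log(1 + x) without forming 1 + x explicitly
reference = x - x**2 / 2 # series approximation; the next term is far below double precision

naive_err = abs(naive - reference) / reference
precise_err = abs(precise - reference) / reference
print(naive_err, precise_err)
```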
Specialized ufuncs
NumPy has many more ufuncs available, including hyperbolic trig functions, bitwise arithmetic, comparison operators, conversions from radians to degrees, rounding and remainders, and much more. A look through the NumPy documentation reveals a lot of interesting functionality.
Another excellent source for more specialized and obscure ufuncs is the submodule scipy.special. If you want to compute some obscure mathematical function on your data, chances are it is implemented in scipy.special. There are far too many functions to list them all, but the following snippet shows a couple that might come up in a statistics context:
In [21]: from scipy import special
In [22]: # gamma functions (generalized factorials) and related functions
x = [1, 5, 10]
print("gamma(x)     =", special.gamma(x))    # gamma function
print("ln|gamma(x)| =", special.gammaln(x))  # natural log of the gamma function
print("beta(x, 2)   =", special.beta(x, 2))  # beta function (Euler integral of the first kind)
gamma(x)     = [1.0000e+00 2.4000e+01 3.6288e+05]
ln|gamma(x)| = [ 0.          3.17805383 12.80182748]
beta(x, 2)   = [0.5        0.03333333 0.00909091]
In [23]: # error function (integral of a Gaussian),
# its complement, and its inverse
x = np.array([0, 0.3, 0.7, 1.0])
print("erf(x)    =", special.erf(x))     # error function
print("erfc(x)   =", special.erfc(x))    # complementary error function
print("erfinv(x) =", special.erfinv(x))  # inverse error function
erf(x)    = [0.         0.32862676 0.67780119 0.84270079]
erfc(x)   = [1.         0.67137324 0.32219881 0.15729921]
erfinv(x) = [0.         0.27246271 0.73286908        inf]
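The gamma function generalizes the factorial: gamma(n) equals (n - 1)! for positive integers, which makes for an easy check. The sketch also shows why gammaln is useful. gamma overflows a double for fairly modest arguments, while its logarithm stays finite.

```python
import math
import numpy as np
from scipy import special

n = np.arange(1, 11)
factorials = np.array([math.factorial(k - 1) for k in range(1, 11)], dtype=float)
print(np.allclose(special.gamma(n), factorials))   # gamma(n) == (n - 1)!

# gamma(200) = 199! is far beyond the largest double, so it overflows to inf,
# while gammaln(200) = ln(199!) is an ordinary finite number
print(special.gamma(200.0), special.gammaln(200.0))
```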
There are many, many more ufuncs available in both NumPy and scipy.special. Because the documentation of these packages is available online, a web search along the lines of "gamma function python" will generally find the relevant information.
Advanced Ufunc Features
Many NumPy users make use of ufuncs without ever learning their full set of features. We'll outline a few specialized features of ufuncs here.
Specifying output
For large calculations, it is sometimes useful to be able to specify the array where the result of the calculation will be stored. Rather than creating a temporary array, you can use this to write computation results directly to the memory location where you'd like them to be. For all ufuncs, this can be done using the out argument of the function:
In [24]: x = np.arange(5)
y = np.empty(5)
np.multiply(x, 10, out=y)  # store the result directly in y
print(y)
[ 0. 10. 20. 30. 40.]
This can even be used with array views. For example, we can write the results of a computation to every other element of a specified array:
In [25]: y = np.zeros(10)
np.power(2, x, out=y[::2])  # write the results into every other element of y
print(y)
[ 1.  0.  2.  0.  4.  0.  8.  0. 16.  0.]
If we had instead written y[::2] = 2 ** x, this would have resulted in the creation of a temporary array to hold the results of 2 ** x, followed by a second operation copying those values into the y array. This doesn't make much of a difference for such a small computation, but for very large arrays the memory savings from careful use of the out argument can be significant.
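Here is a minimal sketch comparing the two approaches on small illustrative arrays. Both fill the same elements, but only the second builds a temporary array first. It also shows that out= can point at an input, which updates that array in place and returns it.

```python
import numpy as np

x = np.arange(5)

y_out = np.zeros(10)
np.power(2, x, out=y_out[::2])     # results written straight into the strided view

y_assign = np.zeros(10)
y_assign[::2] = 2 ** x             # builds a temporary array, then copies it in

print(np.array_equal(y_out, y_assign))

# out= may also name an input: the result overwrites it, and the ufunc returns that array
a = np.arange(5, dtype=float)
result = np.multiply(a, 10, out=a)
print(result is a, a)
```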
Aggregates
For binary ufuncs, there are some interesting aggregates that can be computed directly from the object. For example, if we'd like to reduce an array with a particular operation, we can use the reduce method of any ufunc. A reduce repeatedly applies a given operation to the elements of an array until only a single result remains.
For example, calling reduce on the add ufunc returns the sum of all elements in the array:
In [26]: x = np.arange(1, 6)
np.add.reduce(x)
Out[26]: 15
Similarly, calling reduce on the multiply ufunc results in the product of all array elements:
In [27]: np.multiply.reduce(x)
Out[27]: 120
If we'd like to store all the intermediate results of the computation, we can instead use accumulate:
In [28]: np.add.accumulate(x)
Out[28]: array([ 1,  3,  6, 10, 15])
In [29]: np.multiply.accumulate(x)
Out[29]: array([  1,   2,   6,  24, 120])
Note that for these particular cases, there are dedicated NumPy functions to compute the results (np.sum, np.prod, np.cumsum, np.cumprod), which we'll explore in Aggregations: Min, Max, and Everything In Between.
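For multi-dimensional input, reduce and accumulate take an axis argument (axis=0 by default). Any binary ufunc can be reduced, not just add and multiply. Here is a small sketch on an illustrative 2 × 3 array:

```python
import numpy as np

M = np.arange(1, 7).reshape((2, 3))   # [[1, 2, 3], [4, 5, 6]]

print(np.add.reduce(M))               # axis=0 by default: column sums
print(np.add.reduce(M, axis=1))       # row sums
print(np.maximum.reduce(M, axis=1))   # reducing another binary ufunc: row maxima
print(np.add.accumulate(M, axis=1))   # running sums along each row

# the dedicated function agrees with the ufunc method
print(np.array_equal(np.add.reduce(M, axis=1), np.sum(M, axis=1)))
```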
Outer products
Finally, any ufunc can compute the output of all pairs of two different inputs using the outer method. This allows you, in one line, to do things like create a multiplication table:
In [30]: x = np.arange(1, 6)
np.multiply.outer(x, x)
Out[30]: array([[ 1, 2, 3, 4, 5],
[ 2, 4, 6, 8, 10],
[ 3, 6, 9, 12, 15],
[ 4, 8, 12, 16, 20],
[ 5, 10, 15, 20, 25]])
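outer works with any binary ufunc and with inputs of different lengths: entry [i, j] of the result is ufunc(x[i], y[j]). In the sketch below, the last line builds the same table from a column vector made with np.newaxis. That version relies on broadcasting, which is covered later.

```python
import numpy as np

x = np.arange(1, 4)          # [1, 2, 3]
y = np.array([10, 20])
table = np.add.outer(x, y)   # entry [i, j] is x[i] + y[j]
print(table)
print(table.shape)           # (len(x), len(y))

# the same table via a column vector and broadcasting
print(np.array_equal(table, x[:, np.newaxis] + y))
```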
The ufunc.at and ufunc.reduceat methods, which we'll explore in Fancy Indexing, are very helpful as well.
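As a brief preview sketch: ufunc.at applies an operation unbuffered at the given indices, so a repeated index is applied every time it appears. A fancy-indexed += applies it only once. reduceat reduces over the slices that start at each given index.

```python
import numpy as np

idx = [0, 0, 2]

a = np.zeros(3)
a[idx] += 1                  # buffered: the repeated index 0 is incremented only once
print(a)

b = np.zeros(3)
np.add.at(b, idx, 1)         # unbuffered: every occurrence of an index is applied
print(b)

# sums over the slices [0:3], [3:5], and [5:]
print(np.add.reduceat(np.arange(8), [0, 3, 5]))
```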
Another extremely useful feature of ufuncs is the ability to operate between arrays of different sizes and shapes, a set of operations known as broadcasting. This subject is important enough that we will devote a whole section to it (see Computation on Arrays: Broadcasting).
Ufuncs: Learning More
More information on universal functions (including the full list of available functions) can be found on the NumPy and SciPy documentation websites.
Recall that you can also access information directly from within IPython by importing the packages and using IPython's tab-completion and help (?) functionality, as described in Help and Documentation in IPython.
Aggregations: Min, Max, and Everything In Between
Often when faced with a large amount of data, a first step is to compute summary statistics for the data in question. Perhaps the most common summary statistics are the mean and standard deviation, which allow you to summarize the "typical" values in a dataset, but other aggregates are useful as well (the sum, product, median, minimum and maximum, quantiles, etc.).
NumPy has fast built-in aggregation functions for working on arrays; we'll discuss and demonstrate some of them here.
Summing the Values in an Array
在数组中求总和
As a quick example, consider computing the sum of all values in an array. Python itself can do this using the built-in sum
function:
In [1]: import numpy as np
In [2]: L = np.random.random(100)
sum(L)
Out[2]: 54.47499738668567
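The syntax is quite similar to that of NumPy's sum function, and the result is the same in the simplest case, up to floating-point rounding: the two printed values differ only in their final digit, because NumPy adds the values in a different order than the built-in function: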
In [3]: np.sum(L)
Out[3]: 54.47499738668566
However, because it executes the operation in compiled code, NumPy's version of the operation is computed much more
quickly:
In [4]: big_array = np.random.rand(1000000)
%timeit sum(big_array)
%timeit np.sum(big_array)
88.3 ms ± 2.84 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
564 µs ± 17.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Be careful, though: the sum function and the np.sum function are not identical, which can sometimes lead to
confusion! In particular, their optional arguments have different meanings, and np.sum is aware of multiple array
dimensions, as we will see in the following section.
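The sketch below makes the difference in optional arguments concrete. It uses a made-up 2 × 3 array M. The second positional argument is start for the built-in sum but axis for np.sum, so calls that look alike behave very differently:

```python
import numpy as np

M = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# Built-in sum iterates over the first axis (the rows) and adds them with +,
# so on a 2D array it returns a 1D array, not a single number.
builtin_rows = sum(M)         # array([3, 5, 7])
# Its optional second argument is a start value added to the result.
builtin_start = sum(M, 10)    # array([13, 15, 17])

# np.sum reduces over every element by default ...
numpy_total = np.sum(M)       # 15
# ... and its second positional argument is the axis to collapse.
numpy_axis0 = np.sum(M, 0)    # array([3, 5, 7])

try:
    np.sum(M, 10)             # a 2D array has no axis 10
except ValueError as err:     # NumPy's AxisError subclasses ValueError
    print("np.sum(M, 10) failed:", err)
```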
Minimum and Maximum
Similarly, Python has built-in min and max functions, used to find the minimum value and maximum value of any given
array:
In [5]: min(big_array), max(big_array)
Out[5]: (2.5903288636275335e-06, 0.9999992774771906)
NumPy's corresponding functions have similar syntax, and again operate much more quickly:
In [6]: np.min(big_array), np.max(big_array)
Out[6]: (2.5903288636275335e-06, 0.9999992774771906)
In [7]: %timeit min(big_array)
%timeit np.min(big_array)
61 ms ± 1.39 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
744 µs ± 45.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
For min , max , sum , and several other NumPy aggregates, a shorter syntax is to use methods of the array object
itself:
In [8]: print(big_array.min(), big_array.max(), big_array.sum())
2.5903288636275335e-06 0.9999992774771906 499718.9807141967
Whenever possible, make sure that you are using the NumPy version of these aggregates when operating on NumPy
arrays!
Multidimensional aggregates
One common type of aggregation operation is an aggregate along a row or column. Say you have some data stored in a
two-dimensional array:
In [9]: M = np.random.random((3, 4))
print(M)
[[0.27614977 0.75224804 0.69322493 0.55140476]
[0.6698524 0.92722784 0.20959198 0.78538042]
[0.05078426 0.58621268 0.7614707 0.77247016]]
By default, each NumPy aggregation function will return the aggregate over the entire array:
In [10]: M.sum()
Out[10]: 7.03601794886263
Aggregation functions take an additional argument specifying the axis along which the aggregate is computed. For
example, we can find the minimum value within each column by specifying axis=0 :
In [11]: M.min(axis=0)
Out[11]: array([0.05078426, 0.58621268, 0.20959198, 0.55140476])
The function returns four values, corresponding to the four columns of numbers.
Similarly, we can find the maximum value within each row:
In [12]: M.max(axis=1)
Out[12]: array([0.75224804, 0.92722784, 0.77247016])
The way the axis is specified here can be confusing to users coming from other languages. The axis keyword
specifies the dimension of the array that will be collapsed, rather than the dimension that will be returned. So specifying
axis=0 means that the first axis will be collapsed: for two-dimensional arrays, this means that values within each
column will be aggregated.
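You can check which dimension is collapsed by looking at the shapes of the results. This sketch uses an arbitrary 3 × 4 array. keepdims=True is a standard argument of NumPy reductions; it keeps the collapsed axis with length 1 instead of dropping it:

```python
import numpy as np

M = np.arange(12).reshape(3, 4)   # 3 rows, 4 columns

col_min = M.min(axis=0)   # axis 0 (rows) collapsed -> one value per column
row_max = M.max(axis=1)   # axis 1 (columns) collapsed -> one value per row
col_sum_kept = M.sum(axis=0, keepdims=True)  # collapsed axis kept with length 1

print(col_min, col_min.shape)            # [0 1 2 3] (4,)
print(row_max, row_max.shape)            # [ 3  7 11] (3,)
print(col_sum_kept, col_sum_kept.shape)  # [[12 15 18 21]] (1, 4)
```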
Other aggregation functions
NumPy provides many other aggregation functions, but we won't discuss them in detail here. Additionally, most aggregates have a NaN-safe counterpart that computes the result while ignoring missing values, which are marked by the special IEEE floating-point NaN value (for a fuller discussion of missing data, see Handling Missing Data). Some of these NaN-safe functions were not added until NumPy 1.8, so they will not be available in older NumPy versions.
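Here is a minimal sketch of the difference, using a made-up array with one missing value. The ordinary aggregates propagate NaN, while the NaN-safe versions skip it:

```python
import numpy as np

data = np.array([1.0, np.nan, 3.0, 4.0])

print(np.sum(data), np.nansum(data))    # nan 8.0
print(np.mean(data), np.nanmean(data))  # nan, then the mean of 1, 3 and 4
print(np.max(data), np.nanmax(data))    # nan 4.0
```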
The following table provides a list of useful aggregation functions available in NumPy:
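Function Name    NaN-safe Version    Description
np.sum           np.nansum           Compute sum of elements
np.prod          np.nanprod          Compute product of elements
np.mean          np.nanmean          Compute mean of elements
np.std           np.nanstd           Compute standard deviation
np.var           np.nanvar           Compute variance
np.min           np.nanmin           Find minimum value
np.max           np.nanmax           Find maximum value
np.argmin        np.nanargmin        Find index of minimum value
np.argmax        np.nanargmax        Find index of maximum value
np.median        np.nanmedian        Compute median of elements
np.percentile    np.nanpercentile    Compute rank-based statistics (percentiles) of elements
np.any           N/A                 Evaluate whether any elements are true
np.all           N/A                 Evaluate whether all elements are true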
We will see these aggregates often throughout the rest of the book.
Example: What is the Average Height of US Presidents?
Aggregates available in NumPy can be extremely useful for summarizing a set of values. As a simple example, let's
consider the heights of all US presidents. This data is available in the file president_heights.csv, which is a simple
comma-separated list of labels and values:
In [13]: !head -4 data/president_heights.csv
order,name,height(cm)
1,George Washington,189
2,John Adams,170
3,Thomas Jefferson,189
We'll use the Pandas package, which we'll explore more fully in Chapter 3, to read the file and extract this information
(note that the heights are measured in centimeters).
In [14]: import pandas as pd
data = pd.read_csv('data/president_heights.csv')
heights = np.array(data['height(cm)'])
print(heights)
[189 170 189 163 183 171 185 168 173 183 173 173 175 178 183 193 178 173
174 183 183 168 170 178 182 180 183 178 182 188 175 179 183 193 182 183
177 185 188 188 182 185]
Now that we have this data array, we can compute a variety of summary statistics:
In [15]: print("Mean height:       ", heights.mean())
print("Standard deviation:", heights.std())
print("Minimum height:    ", heights.min())
print("Maximum height:    ", heights.max())
Mean height:        179.73809523809524
Standard deviation: 6.931843442745892
Minimum height:     163
Maximum height:     193
Note that in each case, the aggregation operation reduced the entire array to a single summarizing value, which gives us
information about the distribution of values. We may also wish to compute quantiles:
In [16]: print("25th percentile:   ", np.percentile(heights, 25))
print("Median:            ", np.median(heights))
print("75th percentile:   ", np.percentile(heights, 75))
25th percentile:    174.25
Median:             182.0
75th percentile:    183.0
We see that the median height of US presidents is 182 cm, or just shy of six feet.
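Why is the 25th percentile 174.25 when every height is a whole number? With its default settings, np.percentile interpolates linearly between the two nearest sorted values. The sketch below reproduces that calculation by hand on a hypothetical four-element sample, not the presidents data:

```python
import numpy as np

sample = np.array([183, 163, 174, 170])  # hypothetical heights in cm
q = 25

s = np.sort(sample)                      # [163 170 174 183]
pos = q / 100 * (len(s) - 1)             # fractional index into s: 0.75
lo = int(np.floor(pos))
hi = int(np.ceil(pos))
manual = s[lo] + (pos - lo) * (s[hi] - s[lo])   # 163 + 0.75 * 7

print(manual, np.percentile(sample, q))  # 168.25 168.25
```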
Of course, sometimes it's more useful to see a visual representation of this data, which we can accomplish using tools in
Matplotlib (we'll discuss Matplotlib more fully in Chapter 4). For example, this code generates the following chart:
In [17]: %matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set()  # use seaborn's plot style
In [18]: plt.hist(heights)
plt.title('Height Distribution of US Presidents')
plt.xlabel('height (cm)')
plt.ylabel('number');
These aggregates are some of the fundamental pieces of exploratory data analysis that we'll explore in more depth in
later chapters of the book.
Computation on Arrays: Broadcasting
We saw in the previous section how NumPy's universal functions can be used to vectorize operations and thereby
remove slow Python loops. Another means of vectorizing operations is to use NumPy's broadcasting functionality.
Broadcasting is simply a set of rules for applying binary ufuncs (e.g., addition, subtraction, multiplication, etc.) on arrays
of different sizes.
Introducing Broadcasting
Recall that for arrays of the same size, binary operations are performed on an element-by-element basis:
In [1]: import numpy as np
In [2]: a = np.array([0, 1, 2])
b = np.array([5, 5, 5])
a + b
Out[2]: array([5, 6, 7])
Broadcasting allows these types of binary operations to be performed on arrays of different sizes. For example, we can just as easily add a scalar (think of it as a zero-dimensional array) to an array:
In [3]: a + 5
Out[3]: array([5, 6, 7])
We can think of this as an operation that stretches or duplicates the value 5 into the array [5, 5, 5] , and adds the
results. The advantage of NumPy's broadcasting is that this duplication of values does not actually take place, but it is a
useful mental model as we think about broadcasting.
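You can see that no copy is made by using np.broadcast_to. It returns the stretched array as a read-only view, and its stride of 0 means every element refers to the same value in memory. A small sketch:

```python
import numpy as np

a = np.array([0, 1, 2])
five = np.broadcast_to(np.array(5), a.shape)   # "stretch" 5 to shape (3,)

print(five)                  # [5 5 5]
print(five.strides)          # (0,): every element points at the same memory
print(five.flags.writeable)  # False: broadcast views are read-only
print(a + five)              # [5 6 7], identical to a + 5
```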
We can similarly extend this to arrays of higher dimension. Observe the result when we add a one-dimensional array to a
two-dimensional array:
In [4]: M = np.ones((3, 3))
M
Out[4]: array([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])
In [5]: M + a
Out[5]: array([[1., 2., 3.],
[1., 2., 3.],
[1., 2., 3.]])
Here the one-dimensional array a is stretched, or broadcast, along the first dimension (it is repeated once for each row) in order to match the shape of M.
While these examples are relatively easy to understand, more complicated cases can involve broadcasting of both
arrays. Consider the following example:
In [6]: a = np.arange(3)
b = np.arange(3)[:, np.newaxis]
print(a)
print(b)
[0 1 2]
[[0]
[1]
[2]]
In [7]: a + b
Out[7]: array([[0, 1, 2],
[1, 2, 3],
[2, 3, 4]])
Just as before we stretched or broadcasted one value to match the shape of the other, here we've stretched both a and
b to match a common shape, and the result is a two-dimensional array! The geometry of these examples is visualized
in the following figure (Code to produce this plot can be found in the appendix, and is adapted from source published in
the astroML documentation. Used by permission).
[Figure: Broadcasting Visual]
The light boxes represent the broadcasted values: again, this extra memory is not actually allocated in the course of the
operation, but it can be useful conceptually to imagine that it is.
Rules of Broadcasting
Broadcasting in NumPy follows a strict set of rules to determine the interaction between the two arrays:
Rule 1: If the two arrays differ in their number of dimensions, the shape of the one with fewer dimensions is padded
with ones on its leading (left) side.
Rule 2: If the shape of the two arrays does not match in any dimension, the array with shape equal to 1 in that
dimension is stretched to match the other shape.
Rule 3: If in any dimension the sizes disagree and neither is equal to 1, an error is raised.
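The three rules are mechanical enough to write down as code. The function below, broadcast_shape, is a hypothetical helper written only for illustration and is not part of NumPy. It applies the rules to two shape tuples, and its answers can be compared with NumPy's own np.broadcast_shapes (available in NumPy 1.20 and later):

```python
import numpy as np


def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: apply the three broadcasting rules to two shapes."""
    # Rule 1: pad the shape with fewer dimensions with ones on the left.
    ndim = max(len(shape_a), len(shape_b))
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)

    result = []
    for size_a, size_b in zip(a, b):
        if size_a == size_b:
            result.append(size_a)
        elif size_a == 1:      # Rule 2: stretch a's length-1 dimension
            result.append(size_b)
        elif size_b == 1:      # Rule 2: stretch b's length-1 dimension
            result.append(size_a)
        else:                  # Rule 3: sizes disagree and neither is 1
            raise ValueError(f"shapes {shape_a} and {shape_b} cannot be broadcast")
    return tuple(result)


print(broadcast_shape((2, 3), (3,)))      # (2, 3)
print(broadcast_shape((3, 1), (3,)))      # (3, 3)
print(np.broadcast_shapes((3, 1), (3,)))  # (3, 3), NumPy's own answer
try:
    broadcast_shape((3, 2), (3,))
except ValueError as err:
    print(err)
```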
To make these rules clear, let's consider a few examples in detail.
Broadcasting example 1
Let's look at adding a two-dimensional array to a one-dimensional array:
In [8]: M = np.ones((2, 3))
a = np.arange(3)
Let's consider an operation on these two arrays. The shapes of the arrays are:
M.shape = (2, 3)
a.shape = (3,)
We see by rule 1 that the array a has fewer dimensions, so we pad it on the left with ones:
M.shape -> (2, 3)
a.shape -> (1, 3)
By rule 2, we now see that the first dimension disagrees, so we stretch this dimension to match:
M.shape -> (2, 3)
a.shape -> (2, 3)
The shapes match, and we see that the final shape will be (2, 3) :
In [9]: M + a
Out[9]: array([[1., 2., 3.],
[1., 2., 3.]])
Broadcasting example 2
Let's take a look at an example where both arrays need to be broadcast:
In [10]: a = np.arange(3).reshape((3, 1))
b = np.arange(3)
Again, we'll start by writing out the shape of the arrays:
a.shape = (3, 1)
b.shape = (3,)
Rule 1 says we must pad the shape of b with ones:
a.shape -> (3, 1)
b.shape -> (1, 3)
And rule 2 tells us that we upgrade each of these ones to match the corresponding size of the other array:
a.shape -> (3, 3)
b.shape -> (3, 3)
Because the result matches, these shapes are compatible. We can see this here:
In [11]: a + b
Out[11]: array([[0, 1, 2],
[1, 2, 3],
[2, 3, 4]])
Broadcasting example 3
Now let's take a look at an example in which the two arrays are not compatible:
In [12]: M = np.ones((3, 2))
a = np.arange(3)
This is just a slightly different situation than in the first example: the matrix M is transposed. How does this affect the calculation? The shapes of the arrays are:
M.shape = (3, 2)
a.shape = (3,)
Again, rule 1 tells us that we must pad the shape of a with ones:
M.shape -> (3, 2)
a.shape -> (1, 3)
By rule 2, the first dimension of a is stretched to match that of M :
M.shape -> (3, 2)
a.shape -> (3, 3)
Now we hit rule 3: the final shapes do not match, so these two arrays are incompatible, as we can observe by attempting this operation:
In [13]: M + a
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-13-8cac1d547906> in <module>
----> 1 M + a
ValueError: operands could not be broadcast together with shapes (3,2) (3,)
Note the potential confusion here: you could imagine making a and M compatible by, say, padding a's shape with ones on the right rather than the left. But this is not how the broadcasting rules work! That sort of flexibility might be useful in some cases, but it would lead to potential areas of ambiguity. If right-side padding is what you'd like, you can do this explicitly by reshaping the array (we'll use np.newaxis, an alias for None, introduced in The Basics of NumPy Arrays):
In [14]: a[:, np.newaxis].shape
Out[14]: (3, 1)
In [15]: M + a[:, np.newaxis]
Out[15]: array([[1., 1.],
[2., 2.],
[3., 3.]])
Also note that while we've been focusing on the + operator here, these broadcasting rules apply to any binary ufunc .
For example, here is the logaddexp(a, b) function, which computes log(exp(a) + exp(b)) with more
precision than the naive approach:
In [16]: np.logaddexp(M, a[:, np.newaxis])
Out[16]: array([[1.31326169, 1.31326169],
[1.69314718, 1.69314718],
[2.31326169, 2.31326169]])
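The sketch below shows why the extra precision matters, using deliberately large, made-up inputs. exp(1000) overflows double-precision floats, so the naive formula returns inf, while np.logaddexp returns the expected 1000 + log 2:

```python
import numpy as np

a = np.array([1000.0, 0.0])
b = np.array([1000.0, 1.0])

with np.errstate(over="ignore"):           # silence the overflow warning
    naive = np.log(np.exp(a) + np.exp(b))  # exp(1000) overflows to inf
stable = np.logaddexp(a, b)

print(naive)    # first entry is inf
print(stable)   # first entry is 1000 + log(2), about 1000.693
```

For small inputs (the second entry) both approaches agree. The difference only appears once exp overflows or underflows.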
For more information on the many available universal functions, refer to Computation on NumPy Arrays: Universal
Functions.
Broadcasting in Practice
Broadcasting operations form the core of many examples we'll see throughout this book. We'll now take a look at a couple of simple examples of where they can be useful.
Centering an array
In the previous section, we saw that ufuncs allow a NumPy user to remove the need to explicitly write slow Python loops.
Broadcasting extends this ability. One commonly seen example is when centering an array of data. Imagine you have an
array of 10 observations, each of which consists of 3 values. Using the standard convention (see Data Representation in
Scikit-Learn), we'll store this in a 10 × 3 array:
In [17]: X = np.random.random((10, 3))
We can compute the mean of each feature using the mean aggregate across the first dimension:
In [18]: Xmean = X.mean(0)
Xmean
Out[18]: array([0.61839754, 0.51852053, 0.65514576])
And now we can center the X array by subtracting the mean (this is a broadcasting operation):
In [19]: X_centered = X - Xmean
To double-check that we've done this correctly, we can check that the centered array has a near-zero mean for each feature:
In [20]: X_centered.mean(0)
Out[20]: array([-1.11022302e-17, -7.77156117e-17,
6.66133815e-17])
To within machine precision, the mean is now zero.
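The same idea extends to other per-feature transformations. The sketch below uses random data and assumes NumPy 1.17 or later for np.random.default_rng. It first standardizes each column, then centers each row instead. For the row-wise case, keepdims=True keeps the row means as a (10, 1) column so that rule 2 can stretch it:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((10, 3))

# Column-wise: (10, 3) with (3,) -> rule 1 pads (3,) to (1, 3), rule 2 stretches it.
X_std = (X - X.mean(axis=0)) / X.std(axis=0)

# Row-wise: X.mean(axis=1) has shape (10,), which rule 1 would pad to (1, 10)
# and which then clashes with (10, 3). keepdims=True returns shape (10, 1) instead.
X_rows = X - X.mean(axis=1, keepdims=True)

print(np.allclose(X_std.mean(axis=0), 0), np.allclose(X_std.std(axis=0), 1))  # True True
print(np.allclose(X_rows.mean(axis=1), 0))  # True
```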
Plotting a two-dimensional function
One place that broadcasting is very useful is in displaying images based on two-dimensional functions. If we want to
define a function z = f(x, y), broadcasting can be used to compute the function across the grid:
In [21]: # x and y have 50 steps from 0 to 5
x = np.linspace(0, 5, 50)
y = np.linspace(0, 5, 50)[:, np.newaxis]
z = np.sin(x) ** 10 + np.cos(10 + y * x) * np.cos(x)
We'll use Matplotlib to plot this two-dimensional array (these tools will be discussed in full in Density and Contour Plots):
In [22]: %matplotlib inline
import matplotlib.pyplot as plt
In [23]: plt.imshow(z, origin='lower', extent=[0, 5, 0, 5],
cmap='viridis')
plt.colorbar();
The result is a compelling visualization of the two-dimensional function.
Comparisons, Masks, and Boolean Logic
This section covers the use of Boolean masks to examine and manipulate values within NumPy arrays. Masking comes
up when you want to extract, modify, count, or otherwise manipulate values in an array based on some criterion: for
example, you might wish to count all values greater than a certain value, or perhaps remove all outliers that are above
some threshold. In NumPy, Boolean masking is often the most efficient way to accomplish these types of tasks.
Example: Counting Rainy Days
Imagine you have a series of data that represents the amount of precipitation each day for a year in a given city. For
example, here we'll load the daily rainfall statistics for the city of Seattle in 2014, using Pandas (which is covered in more
detail in Chapter 3):
In [1]: import numpy as np
import pandas as pd
# use Pandas to extract the rainfall data as a NumPy array
rainfall = pd.read_csv('data/Seattle2014.csv')['PRCP'].values
inches = rainfall / 254.0  # convert from 0.1 mm units to inches
inches.shape
Out[1]: (365,)
The array contains 365 values, giving daily rainfall in inches from January 1 to December 31, 2014.
As a first quick visualization, let's look at the histogram of rainy days, which was generated using Matplotlib (we will
explore this tool more fully in Chapter 4):
In [2]: %matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set()  # set plot style
In [3]: plt.hist(inches, 40);  # 40 bins across the rainfall range; bar height is the number of days
This histogram gives us a general idea of what the data looks like: despite its reputation, the vast majority of days in
Seattle saw near zero measured rainfall in 2014. But this doesn't do a good job of conveying some information we'd like
to see: for example, how many rainy days were there in the year? What is the average precipitation on those rainy days?
How many days were there with more than half an inch of rain?
Digging into the data
One approach to this would be to answer these questions by hand: loop through the data, incrementing a counter each
time we see values in some desired range. For reasons discussed throughout this chapter, such an approach is very
inefficient, both from the standpoint of time writing code and time computing the result. We saw in Computation on
NumPy Arrays: Universal Functions that NumPy's ufuncs can be used in place of loops to do fast element-wise arithmetic
operations on arrays; in the same way, we can use other ufuncs to do element-wise comparisons over arrays, and we
can then manipulate the results to answer the questions we have. We'll leave the data aside for right now, and discuss
some general tools in NumPy to use masking to quickly answer these types of questions.
Comparison Operators as ufuncs
In Computation on NumPy Arrays: Universal Functions we introduced ufuncs, and focused in particular on arithmetic
operators. We saw that using + , - , * , / , and others on arrays leads to element-wise operations. NumPy also
implements comparison operators such as < (less than) and > (greater than) as element-wise ufuncs. The result of
these comparison operators is always an array with a Boolean data type. All six of the standard comparison operations
are available:
In [4]: x = np.array([1, 2, 3, 4, 5])
In [5]: x < 3  # less than
Out[5]: array([ True,  True, False, False, False])
In [6]: x > 3  # greater than
Out[6]: array([False, False, False,  True,  True])
In [7]: x <= 3  # less than or equal
Out[7]: array([ True,  True,  True, False, False])
In [8]: x >= 3  # greater than or equal
Out[8]: array([False, False,  True,  True,  True])
In [9]: x != 3  # not equal
Out[9]: array([ True,  True, False,  True,  True])
In [10]: x == 3  # equal
Out[10]: array([False, False,  True, False, False])
It is also possible to do an element-wise comparison of two arrays, and to include compound expressions:
In [11]: (2 * x) == (x ** 2)
Out[11]: array([False,  True, False, False, False])
As in the case of arithmetic operators, the comparison operators are implemented as ufuncs in NumPy; for example,
when you write x < 3 , internally NumPy uses np.less(x, 3) . A summary of the comparison operators and their
equivalent ufunc is shown here:
Operator    Equivalent ufunc
==          np.equal
!=          np.not_equal
<           np.less
<=          np.less_equal
>           np.greater
>=          np.greater_equal
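You can confirm the table by applying each operator and its ufunc to the same input and comparing the results. This sketch uses Python's standard operator module to look up each operator as a function:

```python
import operator

import numpy as np

x = np.array([1, 2, 3, 4, 5])

# Python comparison operator -> the ufunc NumPy uses for it on arrays
equivalents = {
    operator.eq: np.equal,
    operator.ne: np.not_equal,
    operator.lt: np.less,
    operator.le: np.less_equal,
    operator.gt: np.greater,
    operator.ge: np.greater_equal,
}

for op, ufunc in equivalents.items():
    assert np.array_equal(op(x, 3), ufunc(x, 3)), ufunc.__name__

print(np.less(x, 3))  # [ True  True False False False]
```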
Just as in the case of arithmetic ufuncs, these will work on arrays of any size and shape. Here is a two-dimensional
example:
In [12]: rng = np.random.RandomState(0)
x = rng.randint(10, size=(3, 4))
x
Out[12]: array([[5, 0, 3, 3],
[7, 9, 3, 5],
[2, 4, 7, 6]])
In [13]: x < 6
Out[13]: array([[ True, True, True, True],
[False, False, True, True],
[ True, True, False, False]])
In each case, the result is a Boolean array, and NumPy provides a number of straightforward patterns for working with
these Boolean results.
Working with Boolean Arrays
Given a Boolean array, there are a host of useful operations you can do. We'll work with x , the two-dimensional array
we created earlier.
In [14]: print(x)
[[5 0 3 3]
[7 9 3 5]
[2 4 7 6]]
Counting entries
To count the number of True entries in a Boolean array, np.count_nonzero is useful:
In [15]: # how many values less than 6?
np.count_nonzero(x < 6)
Out[15]: 8
We see that there are eight array entries that are less than 6. Another way to get at this information is to use np.sum ; in
this case, False is interpreted as 0 , and True is interpreted as 1 :
In [16]: np.sum(x < 6)
Out[16]: 8
The benefit of sum() is that like with other NumPy aggregation functions, this summation can be done along rows or
columns as well:
In [17]: # how many values less than 6 in each row?
np.sum(x < 6, axis=1)
Out[17]: array([4, 2, 2])
This counts the number of values less than 6 in each row of the matrix.
If we're interested in quickly checking whether any or all the values are true, we can use (you guessed it) np.any or
np.all :
In [18]: # are there any values greater than 8?
np.any(x > 8)
Out[18]: True
In [19]: # are there any values less than zero?
np.any(x < 0)
Out[19]: False
In [20]: # are all values less than 10?
np.all(x < 10)
Out[20]: True
In [21]: # are all values equal to 6?
np.all(x == 6)
Out[21]: False
np.all and np.any can be used along particular axes as well. For example:
In [22]: # are all values in each row less than 8?
np.all(x < 8, axis=1)
Out[22]: array([ True, False,  True])
Here all the elements in the first and third rows are less than 8, while this is not the case for the second row.
Finally, a quick warning: as mentioned in Aggregations: Min, Max, and Everything In Between, Python has built-in
sum() , any() , and all() functions. These have a different syntax than the NumPy versions, and in particular will
fail or produce unintended results when used on multidimensional arrays. Be sure that you are using np.sum() ,
np.any() , and np.all() for these examples!
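The sketch below shows how the built-ins go wrong on the same x array. The built-in sum only adds the rows together. The built-in any fails because it asks for the truth value of each whole row:

```python
import numpy as np

x = np.array([[5, 0, 3, 3],
              [7, 9, 3, 5],
              [2, 4, 7, 6]])
mask = x < 6

print(np.sum(mask))   # 8: counts every True in the array
print(sum(mask))      # [2 2 2 2]: built-in sum only adds the rows together

try:
    any(mask)         # built-in any calls bool() on a 4-element row
except ValueError as err:
    print("built-in any failed:", err)
print(np.any(mask))   # True
```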
Boolean operators
We've already seen how we might count, say, all days with rain less than four inches, or all days with rain greater than
two inches. But what if we want to know about all days with rain less than four inches and greater than one inch? This is
accomplished through Python's bitwise logic operators, & , | , ^ , and ~ . Like with the standard arithmetic operators,
NumPy overloads these as ufuncs which work element-wise on (usually Boolean) arrays.
For example, we can address this sort of compound question as follows:
In [23]: np.sum((inches > 0.5) & (inches < 1))
Out[23]: 29
So we see that there are 29 days with rainfall between 0.5 and 1.0 inches.
Note that the parentheses here are important. Because of operator precedence rules (the bitwise operators bind more tightly than comparison operators), with the parentheses removed this expression would be evaluated as follows, which results in an error:
inches > (0.5 & inches) < 1
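Here is a minimal sketch of that failure on a few made-up rainfall values, next to the correctly parenthesized version:

```python
import numpy as np

inches = np.array([0.0, 0.3, 0.7, 1.2])   # made-up rainfall values

try:
    # Parsed as the chained comparison  inches > (0.5 & inches) < 1,
    # and & (bitwise_and) is not defined for floating-point values.
    inches > 0.5 & inches < 1
except TypeError as err:
    print("TypeError:", err)

count = np.sum((inches > 0.5) & (inches < 1))
print(count)   # 1: only 0.7 lies between 0.5 and 1
```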
Using the equivalence of A AND B and NOT (NOT A OR NOT B) (which you may remember if you've taken an
introductory logic course), we can compute the same result in a different manner:
In [24]: np.sum(~( (inches <= 0.5) | (inches >= 1) ))
Out[24]: 29
Combining comparison operators and Boolean operators on arrays can lead to a wide range of efficient logical
operations.
The following table summarizes the bitwise Boolean operators and their equivalent ufuncs:
Operator    Equivalent ufunc
&           np.bitwise_and
|           np.bitwise_or
^           np.bitwise_xor
~           np.bitwise_not
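This short sketch checks the table on two small made-up Boolean arrays, including the A AND B equivalence used earlier. Note that ~ on an integer array is a genuine bitwise inversion, not a logical NOT:

```python
import numpy as np

A = np.array([True, False, True, False])
B = np.array([True, True, False, False])

assert np.array_equal(A & B, np.bitwise_and(A, B))
assert np.array_equal(A | B, np.bitwise_or(A, B))
assert np.array_equal(A ^ B, np.bitwise_xor(A, B))
assert np.array_equal(~A, np.bitwise_not(A))

# A AND B is the same as NOT (NOT A OR NOT B)
assert np.array_equal(A & B, ~(~A | ~B))

print(A ^ B)              # [False  True  True False]
# On integer arrays, ~ flips every bit: ~1 == -2 and ~0 == -1
print(~np.array([1, 0]))  # [-2 -1]
```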
Using these tools, we might start to answer the types of questions we have about our weather data. Here are some
examples of results we can compute when combining masking with aggregations:
In [27]: print("Number of days without rain:   ", np.sum(inches == 0))
print("Number of days with rain:      ", np.sum(inches != 0))
print("Days with more than 0.5 inches:", np.sum(inches > 0.5))
print("Rainy days with < 0.2 inches:  ", np.sum((inches > 0) & (inches < 0.2)))
Number of days without rain:    215
Number of days with rain:       150
Days with more than 0.5 inches: 37
Rainy days with < 0.2 inches:   75
Boolean Arrays as Masks
In the preceding section we looked at aggregates computed directly on Boolean arrays. A more powerful pattern is to use
Boolean arrays as masks, to select particular subsets of the data themselves. Returning to our x array from before,
suppose we want an array of all values in the array that are less than, say, 5:
In [28]: x
Out[28]: array([[5, 0, 3, 3],
[7, 9, 3, 5],
[2, 4, 7, 6]])
We can obtain a Boolean array for this condition easily, as we've already seen:
In [29]: x < 5
Out[29]: array([[False, True, True, True],
[False, False, True, False],
[ True, True, False, False]])
Now to select these values from the array, we can simply index on this Boolean array; this is known as a masking
operation:
In [30]: x[x < 5]
Out[30]: array([0, 3, 3, 3, 2, 4])
What is returned is a one-dimensional array filled with all the values that meet this condition; in other words, all the values
in positions at which the mask array is True .
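Masks can also appear on the left-hand side of an assignment, which covers the "modify" use case mentioned at the start of this section. The short sketch below uses the same x array. Note that x[x < 5] returns a copy of the selected values rather than a view:

```python
import numpy as np

x = np.array([[5, 0, 3, 3],
              [7, 9, 3, 5],
              [2, 4, 7, 6]])

selected = x[x < 5]     # [0 3 3 3 2 4], a new array rather than a view
selected[0] = 100       # so changing it leaves x untouched

capped = x.copy()
capped[capped > 5] = 5  # a mask on the left-hand side assigns in place
print(capped)
# [[5 0 3 3]
#  [5 5 3 5]
#  [2 4 5 5]]
```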
We are then free to operate on these values as we wish. For example, we can compute some relevant statistics on our
Seattle rain data:
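In [31]: # construct a mask of all rainy days
rainy = (inches > 0)
# construct a mask of all summer days (June 21st is the 172nd day)
days = np.arange(365)
summer = (days > 172) & (days < 262)
print("Median precip on rainy days in 2014 (inches):   ", np.median(inches[rainy]))
print("Median precip on summer days in 2014 (inches):  ", np.median(inches[summer]))
print("Maximum precip on summer days in 2014 (inches): ", np.max(inches[summer]))
print("Median precip on non-summer rainy days (inches):", np.median(inches[rainy & ~summer]))
Median precip on rainy days in 2014 (inches):    0.19488188976377951
Median precip on summer days in 2014 (inches):   0.0
Maximum precip on summer days in 2014 (inches):  0.8503937007874016
Median precip on non-summer rainy days (inches): 0.20078740157480315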
By combining Boolean operations, masking operations, and aggregates, we can very quickly answer these sorts of
questions for our dataset.
结合布尔操作、遮盖操作和聚合操作,我们可以很快在数据集中得到这类问题的答案。
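The rain example depends on the Seattle dataset loaded earlier in the chapter. The sketch below applies the same pattern to a hypothetical ten-day precipitation series. The values and the "summer" window are made up for illustration and are not real data. Masks are combined with & and negated with ~, and each combined mask then selects values for an aggregate.

```python
import numpy as np

# Hypothetical daily precipitation in inches (illustrative values only)
inches = np.array([0.0, 0.2, 0.0, 0.5, 0.1, 0.0, 0.0, 0.3, 0.0, 0.4])
days = np.arange(len(inches))

rainy = inches > 0                  # mask of rainy days
summer = (days > 2) & (days < 7)    # hypothetical "summer": days 3 through 6

print(np.count_nonzero(rainy))              # number of rainy days
print(np.median(inches[rainy]))             # median over rainy days
print(np.max(inches[summer]))               # maximum during "summer"
print(np.median(inches[rainy & ~summer]))   # rainy days outside "summer"
```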
Aside: Using the Keywords and/or Versus the Operators &/|
附加内容:对⽐使⽤and/or关键字和&/|运算符
One common point of confusion is the difference between the keywords and and or on one hand, and the operators
& and | on the other hand. When would you use one versus the other?
使⽤关键字 and 和 or ,与使⽤运算符 & 和 | ,两者的区别,常常会困惑很多⼈。什么情况下你应该⽤哪种运算呢?
The difference is this: and and or gauge the truth or falsehood of the entire object, while & and | refer to bits within
each object.
区别在于: and 和 or ⽤在将整个对象当成真值或假值进⾏运算的场合,⽽ & 和 | 会针对每个对象内的⼆进制位进⾏运算。
When you use and or or , it's equivalent to asking Python to treat the object as a single Boolean entity. In Python, all
nonzero integers will evaluate as True. Thus:
当你使⽤ and 或 or 的时候,相当于要求Python将对象当成是⼀个布尔值的整体。在Python中,所有的⾮0值都会被演算成True,因此:
In [32]: bool(42), bool(0)
Out[32]: (True, False)
In [33]: bool(42 and 0)
Out[33]: False
In [34]: bool(42 or 0)
Out[34]: True
When you use & and | on integers, the expression operates on the bits of the element, applying the and or the or to
the individual bits making up the number:
当你在整数上使⽤ & 和 | 运算时,这两个操作会运算整数中的每个⼆进制位,在每个⼆进制位上执⾏⼆进制与或⼆进制或操作:
In [35]: bin(42)
Out[35]: '0b101010'
In [36]: bin(59)
Out[36]: '0b111011'
In [37]: bin(42 & 59)
Out[37]: '0b101010'
In [38]: bin(42 | 59)
Out[38]: '0b111011'
Notice that the corresponding bits of the binary representation are compared in order to yield the result.
对⽐⼀下上⾯例⼦中的结果是如何从操作数上进⾏⼆进制运算获得的。
When you have an array of Boolean values in NumPy, this can be thought of as a string of bits where 1 = True and 0
= False , and the result of & and | operates similarly to above:
当数组是⼀个NumPy的布尔数组时,你可以将这个布尔数组想象成它是由⼀系列⼆进制位组成的,因为 1 = True 和 0 = False ,所
以使⽤ & 和 | 运算得到的结果类似上⾯的例⼦:
In [39]: A = np.array([1, 0, 1, 0, 1, 0], dtype=bool)
B = np.array([1, 1, 1, 0, 1, 1], dtype=bool)
A | B
Out[39]: array([ True,  True,  True, False,  True,  True])
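For Boolean arrays, & and | give the same element-by-element results as the ufuncs np.logical_and and np.logical_or. Likewise, ~ matches np.logical_not. The sketch below reuses A and B from above to check this.

```python
import numpy as np

A = np.array([1, 0, 1, 0, 1, 0], dtype=bool)
B = np.array([1, 1, 1, 0, 1, 1], dtype=bool)

# Elementwise operators on Boolean arrays...
or_op, and_op, not_op = A | B, A & B, ~A

# ...agree with the corresponding logical ufuncs
print(np.array_equal(or_op, np.logical_or(A, B)))    # True
print(np.array_equal(and_op, np.logical_and(A, B)))  # True
print(np.array_equal(not_op, np.logical_not(A)))     # True
print(and_op)                                        # [ True False  True False  True False]
```

One caution: ~ is a logical NOT only for Boolean dtype. On integer arrays it flips bits, so ~1 is -2.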
Using or on these arrays will try to evaluate the truth or falsehood of the entire array object, which is not a well-defined
value:
在数组间使⽤ or 操作时,等同于要求Python把数组当成⼀个整体来求出最终的真值或假值,这样的值是不存在的,因此会导致⼀个错
误:
In [40]: A or B
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-40-ea2c97d9d9ee> in <module>
----> 1 A or B
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
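The error message points to the fix: state explicitly which single truth value you want. The minimal sketch below uses the same A and B. It also shows that an array in an if statement fails the same way, because if calls bool() on it implicitly.

```python
import numpy as np

A = np.array([1, 0, 1, 0, 1, 0], dtype=bool)
B = np.array([1, 1, 1, 0, 1, 1], dtype=bool)
either = A | B                # elementwise OR

# Reduce to one truth value explicitly, as the error message suggests
print(either.any())           # True: at least one element is True
print(either.all())           # False: element 3 is False in both arrays

# An array used as an if-condition triggers the same implicit bool() call
try:
    if A:
        pass
    raised = False
except ValueError:
    raised = True
print(raised)                 # True
```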
Similarly, when doing a Boolean expression on a given array, you should use | or & rather than or or and :
类似的,当对于给定的数组进⾏布尔表达式运算时,你应该使⽤ | 或 & ,⽽不是 or 或 and :
In [41]: x = np.arange(10)
(x > 4) & (x < 8)
Out[41]: array([False, False, False, False, False,  True,  True,  True, False, False])
Trying to evaluate the truth or falsehood of the entire array will give the same ValueError we saw previously:
同样如果试图把数组当成⼀个整体计算最终真值或假值也是不被允许的,结果还是我们前⾯看到的那个 ValueError :
In [42]: (x > 4) and (x < 8)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-42-eecf1fdd5fb4> in <module>
----> 1 (x > 4) and (x < 8)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
So remember this: and and or perform a single Boolean evaluation on an entire object, while & and | perform
multiple Boolean evaluations on the content (the individual bits or bytes) of an object. For Boolean NumPy arrays, the
latter is nearly always the desired operation.
因此,你只需要记住: and 和 or 对整个对象进行单个布尔操作,而 & 和 | 会对一个对象的内容进行多个布尔操作(比如其中每个二进制位或字节)。对于NumPy布尔数组来说,需要的几乎总是后者。
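One practical consequence concerns the parentheses in (x > 4) & (x < 8). They are required because & and | bind more tightly than comparison operators such as < and >. Without them, Python parses x > 4 & x < 8 as the chained comparison x > (4 & x) < 8. That form again asks for the truth value of a whole array, as this sketch shows:

```python
import numpy as np

x = np.arange(10)

# With parentheses, each comparison runs first and & combines the results
good = (x > 4) & (x < 8)
print(np.nonzero(good)[0])        # [5 6 7]

# Without them, & binds first: x > (4 & x) < 8 is a chained comparison,
# which implicitly calls bool() on an array and fails
try:
    x > 4 & x < 8
    precedence_error = False
except ValueError:
    precedence_error = True
print(precedence_error)           # True
```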
Fancy Indexing
⾼级索引
In the previous sections, we saw how to access and modify portions of arrays using simple indices (e.g., arr[0] ),
slices (e.g., arr[:5] ), and Boolean masks (e.g., arr[arr > 0] ). In this section, we'll look at another style of array
indexing, known as fancy indexing. Fancy indexing is like the simple indexing we've already seen, but we pass arrays of
indices in place of single scalars. This allows us to very quickly access and modify complicated subsets of an array's
values.
在前⾯的⼩节中,我们学习了如何获取和修改数组的元素或部分元素,我们可以通过简单索引(例如 arr[0] ),切⽚(例如
arr[:5] )和布尔遮盖(例如 arr[arr > 0] )来实现。本节来介绍另外⼀种数组索引的⽅式,被称为⾼级索引。⾼级索引语法上和
前⾯我们学习到的简单索引很像,区别只是它不是传递标量参数作为索引值,⽽是传递数组参数作为索引值。它能让我们很迅速的获取和
修改复杂数组或⼦数组的元素值。
Exploring Fancy Indexing
初探⾼级索引
Fancy indexing is conceptually simple: it means passing an array of indices to access multiple array elements at once.
For example, consider the following array:
⾼级索引在概念层⾯⾮常简单:传递⼀个数组作为索引值参数,使得⽤⼾能⼀次性的获取或修改多个数组元素值。例如下⾯的数组:
In [1]: import numpy as np
rand = np.random.RandomState(42)
x = rand.randint(100, size=10)
print(x)
[51 92 14 71 60 20 82 86 74 74]
Suppose we want to access three different elements. We could do it like this:
假如我们需要访问其中三个不同的元素。我们可以这样做:
In [2]: [x[3], x[7], x[2]]
Out[2]: [71, 86, 14]
Alternatively, we can pass a single list or array of indices to obtain the same result:
还有⼀种⽅法,我们以⼀个数组的⽅式将这些元素的索引传递给数组,也可以获得相同的结果:
In [3]: ind = [3, 7, 2]
x[ind]
Out[3]: array([71, 86, 14])
When using fancy indexing, the shape of the result reflects the shape of the index arrays rather than the shape of the
array being indexed:
当使⽤⾼级索引时,结果数组的形状取决于索引数组的形状⽽不是被索引数组的形状:
In [4]: ind = np.array([[3, 7],
[4, 5]])
x[ind]  # a 2x2 index array gives a 2x2 result
Out[4]: array([[71, 86],
[60, 20]])
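To see that the index array drives the output shape, the sketch below indexes the same 1-D array with index arrays of several shapes. The values of x do not matter here, so a simple np.arange stands in for the random array above. This substitution is an assumption made for readability.

```python
import numpy as np

x = np.arange(10) * 10            # stand-in values: x[k] == 10 * k

for ind in (np.array([3, 7]),
            np.array([[3, 7], [4, 5]]),
            np.zeros((2, 3, 1), dtype=int)):
    # the result takes the index array's shape, not x's shape
    print(ind.shape, "->", x[ind].shape)

print(x[np.array([[3, 7], [4, 5]])])
```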
Fancy indexing also works in multiple dimensions. Consider the following array:
⾼级索引也⽀持多维数组。例如:
In [5]: X = np.arange(12).reshape((3, 4))
X
Out[5]: array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
As with standard indexing, the first index refers to the row, and the second to the column:
就像普通索引⼀样,第⼀个参数代表⾏,第⼆个参数代表列:
In [6]: row = np.array([0, 1, 2])
col = np.array([2, 1, 3])
X[row, col]
Out[6]: array([ 2,  5, 11])
Notice that the first value in the result is X[0, 2] , the second is X[1, 1] , and the third is X[2, 3] . The pairing of
indices in fancy indexing follows all the broadcasting rules that were mentioned in Computation on Arrays: Broadcasting.
So, for example, if we combine a column vector and a row vector within the indices, we get a two-dimensional result:
结果中的第一个值是 X[0, 2] ,第二个值是 X[1, 1] ,第三个值是 X[2, 3] 。高级索引的多个维度组合方式也遵守广播的规则,请
查阅在数组上计算:⼴播。因此,如果我们在上⾯的⾏索引数组中增加⼀个维度,结果将变成⼀个⼆维数组:
In [7]: X[row[:, np.newaxis], col]
Out[7]: array([[ 2,  1,  3],
       [ 6,  5,  7],
       [10,  9, 11]])
Here, each row value is matched with each column value, exactly as we saw in broadcasting of arithmetic operations.
For example:
这⾥,每个⾏索引都会匹配每个列的向量,就像我们在⼴播的算术运算中看到⼀样。例如:
In [8]: row[:, np.newaxis] * col
Out[8]: array([[0, 0, 0],
[2, 1, 3],
[4, 2, 6]])
It is always important to remember with fancy indexing that the return value reflects the broadcasted shape of the indices,
rather than the shape of the array being indexed.
记住⾼级索引结果的形状是索引数组⼴播后的形状⽽不是被索引数组形状,这点⾮常重要。
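You can check this rule directly. np.broadcast reports the shape that the index arrays broadcast to, and the fancy-indexing result has exactly that shape. The sketch below reuses X, row and col from above.

```python
import numpy as np

X = np.arange(12).reshape((3, 4))
row = np.array([0, 1, 2])
col = np.array([2, 1, 3])

# Shape the two index arrays broadcast to
b_shape = np.broadcast(row[:, np.newaxis], col).shape
result = X[row[:, np.newaxis], col]

print(b_shape, result.shape)                 # (3, 3) (3, 3)

# Element [i, j] of the result is X[row[i], col[j]]
print(result[2, 0] == X[row[2], col[0]])     # True
```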
Combined Indexing
组合索引
For even more powerful operations, fancy indexing can be combined with the other indexing schemes we've seen:
结合我们前⾯学习过的索引⽅法,我们可以组合出更多更强⼤的操作:
In [9]: print(X)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
We can combine fancy and simple indices:
我们可以将⾼级索引和简单索引进⾏组合:
Translator's note: this is really just broadcasting, with the scalar index broadcast into a vector.
In [10]: X[2, [2, 0, 1]]
Out[10]: array([10,  8,  9])
We can also combine fancy indexing with slicing:
我们也可以将⾼级索引和切⽚进⾏组合:
In [11]: X[1:, [2, 0, 1]]
Out[11]: array([[ 6,  4,  5],
        [10,  8,  9]])
And we can combine fancy indexing with masking:
还可以将⾼级索引和遮盖进⾏组合:
In [12]: mask = np.array([1, 0, 1, 0], dtype=bool)
X[row[:, np.newaxis], mask]
Out[12]: array([[ 0, 2],
[ 4, 6],
[ 8, 10]])
All of these indexing options combined lead to a very flexible set of operations for accessing and modifying array values.
所有这些索引操作可以提供⽤⼾⾮常灵活的⽅式来获取和修改数组中的数据。
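In the last combination, a Boolean mask along one axis behaves like the integer positions of its True entries, which np.nonzero returns. The quick check below reuses X, row and mask from above to confirm the equivalence.

```python
import numpy as np

X = np.arange(12).reshape((3, 4))
row = np.array([0, 1, 2])
mask = np.array([1, 0, 1, 0], dtype=bool)

cols = np.nonzero(mask)[0]                # integer positions of the True entries
via_mask = X[row[:, np.newaxis], mask]
via_ints = X[row[:, np.newaxis], cols]

print(cols)                                # [0 2]
print(np.array_equal(via_mask, via_ints))  # True
```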
Example: Selecting Random Points
例⼦:选择随机点
One common use of fancy indexing is the selection of subsets of rows from a matrix. For example, we might have an N
by D matrix representing N points in D dimensions, such as the following points drawn from a two-dimensional normal
distribution:
高级索引的一个常见应用场景就是从一个矩阵中选取若干行组成子数据集。例如,我们有一个N × D的矩阵,代表着D维空间中的N个点,例如下面的二维正态分布的点集合:
In [13]: mean = [0, 0]
cov = [[1, 2],
[2, 5]]
X = rand.multivariate_normal(mean, cov, 100)
X.shape
Out[13]: (100, 2)
Using the plotting tools we will discuss in Introduction to Matplotlib, we can visualize these points as a scatter-plot:
使⽤我们会在第四章详细介绍的Matplotlib⼯具,我们可以在散点图上绘制这些点:
In [14]: %matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set()  # set the plot style
plt.scatter(X[:, 0], X[:, 1]);
Let's use fancy indexing to select 20 random points. We'll do this by first choosing 20 random indices with no repeats, and then using these indices to select a portion of the original array:
下⾯我们使⽤⾼级索引来选择20个随机点。⽅法是先创建⼀个索引数组,⾥⾯的索引值是没有重复的,然后使⽤这个索引数组来选择点:
In [15]: indices = np.random.choice(X.shape[0], 20, replace=False)
indices
Out[15]: array([42, 64, 11, 76, 46, 33, 77, 14, 91, 20, 13,  4, 60, 49,
        0, 32, 21, 55, 40, 61])
In [16]: selection = X[indices]  # fancy indexing here
selection.shape
Out[16]: (20, 2)
Now to see which points were selected, let's over-plot large circles at the locations of the selected points:
下面我们来看看哪些点被选中,让我们在上图的基础上将选中的点圈出来:
In [17]: plt.scatter(X[:, 0], X[:, 1], alpha=0.3)
plt.scatter(selection[:, 0], selection[:, 1],
facecolor='none', edgecolor='black', s=200);
This sort of strategy is often used to quickly partition datasets, as is often needed in train/test splitting for validation of
statistical models (see Hyperparameters and Model Validation), and in sampling approaches to answering statistical
questions.
这种策略经常用来快速划分数据集,比如用来验证统计模型时需要的训练集和测试集划分(参见超参数及模型验证),还有就是在回答统计问题时进行抽样。
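Here is a minimal sketch of such a split. The 100-row dataset and the 80/20 ratio are hypothetical choices for illustration. The row indices are shuffled once, and fancy indexing then pulls out two disjoint subsets.

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(size=(100, 2))          # hypothetical dataset: 100 points in 2-D

perm = rng.permutation(data.shape[0])     # shuffled row indices, no repeats
n_test = 20                               # illustrative 80/20 split
test_idx, train_idx = perm[:n_test], perm[n_test:]

train, test = data[train_idx], data[test_idx]   # fancy indexing selects whole rows
print(train.shape, test.shape)                  # (80, 2) (20, 2)
```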
Modifying Values with Fancy Indexing
使⽤⾼级索引修改数据
Just as fancy indexing can be used to access parts of an array, it can also be used to modify parts of an array. For
example, imagine we have an array of indices and we'd like to set the corresponding items in an array to some value:
前⾯我们看到⾼级索引能够被⽤来获取⼀个数组的部分数据,实际上它还能⽤来修改选中部分的数据。例如,我们⼿头有⼀个索引的数
组,我们想将这些索引上的数据修改为某个值:
In [18]: x = np.arange(10)
i = np.array([2, 1, 8, 4])
x[i] = 99
print(x)
[ 0 99 99  3 99  5  6  7 99  9]
We can use any assignment-type operator for this. For example:
我们可以使⽤任何赋值类型操作,例如:
In [19]: x[i] -= 10
print(x)
[ 0 89 89  3 89  5  6  7 89  9]
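The right-hand side does not have to be a scalar. It can be any array that broadcasts to the shape of the selection, for example one value per index. The sketch below also shows a related fact: reading with fancy indexing returns a copy, unlike a slice, which returns a view.

```python
import numpy as np

x = np.arange(10)
i = np.array([2, 1, 8, 4])

x[i] = [20, 10, 80, 40]      # values are matched to the indices in order
print(x)                     # [ 0 10 20  3 40  5  6  7 80  9]

# Reading with fancy indexing returns a copy, so writing to it leaves x alone
y = x[i]
y[:] = 0
print(x[i])                  # [20 10 80 40]
```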
Notice, though, that repeated indices with these operations can cause some potentially unexpected results. Consider the
following:
请注意下,如果索引数组中有重复的元素的话,这种修改操作可能会导致⼀个潜在的意料之外的结果。例如:
In [20]: x = np.zeros(10)
x[[0, 0]] = [4, 6]
print(x)
[6. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Where did the 4 go? The result of this operation is to first assign x[0] = 4 , followed by x[0] = 6 . The result, of
course, is that x[0] contains the value 6.
4 跑到哪里去了呢?这个操作首先赋值 x[0] = 4 ,然后赋值 x[0] = 6 ,因此最后 x[0] 的值是6。
Fair enough, but consider this operation:
上⾯的例⼦还算⽐较清晰,再看下⾯这个操作:
In [21]: i = [2, 3, 3, 4, 4, 4]
x[i] += 1
x
Out[21]: array([6., 0., 1., 1., 1., 0., 0., 0., 0., 0.])
You might expect that x[3] would contain the value 2, and x[4] would contain the value 3, as this is how many times
each index is repeated. Why is this not the case? Conceptually, this is because x[i] += 1 is meant as a shorthand for
x[i] = x[i] + 1 . x[i] + 1 is evaluated, and then the result is assigned to the indices in x. With this in mind, it is
not the augmentation that happens multiple times, but the assignment, which leads to the rather nonintuitive results.
So what if you want the other behavior where the operation is repeated? For this, you can use the at() method of
ufuncs (available since NumPy 1.8), and do the following:
我们期望的结果可能是 x[3] 的值是2,⽽ x[4] 的值是3,因为这两个元素都多次执⾏了加法操作。但是为何结果不是呢?这是因为
x[i] += 1 是操作 x[i] = x[i] + 1 的简写,⽽ x[i] + 1 表达式的值已经计算好了,然后才被赋值给 x[i] 。因此,上⾯的操作
不会被扩展为重复的运算,⽽是⼀次的赋值操作,造成了这种难以理解的结果。
如果我们真的需要这种重复的操作怎么办?对此,NumPy(版本1.8以上)提供了 at() ufunc⽅法可以满⾜这个⽬的,如下:
In [22]: x = np.zeros(10)
np.add.at(x, i, 1)
print(x)
[0. 0. 1. 2. 3. 0. 0. 0. 0. 0.]
The at() method does an in-place application of the given operator at the specified indices (here, i ) with the
specified value (here, 1). Another method that is similar in spirit is the reduceat() method of ufuncs, which you can
read about in the NumPy documentation.
at() 方法不会预先计算表达式的值,而是在指定索引处逐次就地执行运算:它在数组 x 中取得特定索引 i 处的值,然后将其与最后一个参数 1 进行相应计算,这里是加法 add 。还有一个类似的方法是 reduceat() 方法,你可以从NumPy的文档中阅读它的说明。
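A common special case is adding 1 at repeated integer indices, which amounts to counting. For that case, np.bincount gives the same result as np.add.at. The sketch below uses the index list i from above. np.add.at also accepts one value per index, which gives a weighted count.

```python
import numpy as np

i = [2, 3, 3, 4, 4, 4]

x = np.zeros(10)
np.add.at(x, i, 1)                      # unbuffered: repeated indices accumulate

counts = np.bincount(i, minlength=10)   # occurrences of each integer 0..9
print(counts)                           # [0 0 1 2 3 0 0 0 0 0]
print(np.array_equal(x, counts))        # True

# Per-index values act like weights
w = np.zeros(10)
np.add.at(w, i, [1.0, 0.5, 0.5, 2.0, 2.0, 2.0])
print(w[3], w[4])                       # 1.0 6.0
```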
Example: Binning Data
例⼦:数据分组
You can use these ideas to efficiently bin data to create a histogram by hand. For example, imagine we have 100 values and would like to quickly find where they fall within an array of bins. We could compute it using ufunc.at like
this:
你可以使用上面的方法对数据进行高效分组,用于定义自己的直方图。例如,设想我们有100个值,我们想快速找到它们分别落入哪个数据分组中。我们可以使用 at 函数,例如:
In [23]: np.random.seed(42)
x = np.random.randn(100)  # 100 values drawn from a standard normal distribution
# compute a custom histogram: 20 equally spaced bin edges from -5 to 5
bins = np.linspace(-5, 5, 20)
counts = np.zeros_like(bins)  # counts[k] will hold how many x values land in bin k
# use searchsorted to find the bin index for each element of x
i = np.searchsorted(bins, x)
# use add with at to count how many elements of x fall into each bin
np.add.at(counts, i, 1)
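The counts now reflect the number of points within each bin; in other words, a histogram:
counts 现在包含着每个数据分组中元素的个数,换句话来说,就是直方图:
Translator's note: as of Matplotlib 3.1, passing a drawing style through the linestyle keyword argument is deprecated and will be removed in a later version. The code below follows the current API and uses drawstyle (or its alias ds) instead.
In [24]: # plot the results
plt.plot(bins, counts, drawstyle='steps');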
Of course, it would be silly to have to do this each time you want to plot a histogram. This is why Matplotlib provides the
plt.hist() routine, which does the same in a single line:
当然,如果每次要画直⽅图的时候,都要经过这么复杂的计算,很不⽅便。这也就是为什么Matplotlib提供了 plt.hist() ⽅法的原因,
可以⽤⼀⾏代码完成上⾯操作:
In [25]: plt.hist(x, bins, histtype='step');
This function will create a nearly identical plot to the one seen here. To compute the binning, matplotlib uses the
np.histogram function, which does a very similar computation to what we did before. Let's compare the two here:
这个函数会创建⼀个和上图基本完全⼀样的图形。Matplotlib使⽤ np.histogram 函数来计算数据分组,这个函数进⾏的计算和我们上⾯
的代码⾮常接近。我们⽐较⼀下这两个⽅法:
In [26]: print("NumPy routine:")
%timeit counts, edges = np.histogram(x, bins)
print("Custom routine:")
%timeit np.add.at(counts, np.searchsorted(bins, x), 1)
NumPy routine:
22.7 µs ± 1.62 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Custom routine:
12.1 µs ± 426 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
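Before reading too much into the timings, it helps to confirm that the two approaches count the same thing. They number their bins differently. np.searchsorted(bins, x) returns k when bins[k-1] < x <= bins[k], so the custom counts[k] corresponds to np.histogram's bin k-1, which covers [bins[k-1], bins[k]). Random floats essentially never land exactly on an edge, so the custom counts, shifted by one position, should match.

```python
import numpy as np

np.random.seed(42)
x = np.random.randn(100)
bins = np.linspace(-5, 5, 20)

# Custom approach from In [23]
counts = np.zeros_like(bins)
np.add.at(counts, np.searchsorted(bins, x), 1)

# NumPy's routine: 19 counts for the 19 intervals between the 20 edges
hist, edges = np.histogram(x, bins)

print(counts[0])                           # 0.0: nothing falls at or below -5
print(np.array_equal(counts[1:], hist))    # True
print(hist.sum())                          # 100
```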
Our own one-line algorithm is roughly twice as fast as the optimized algorithm in NumPy! How can this be? If you dig
into the np.histogram source code (you can do this in IPython by typing np.histogram?? ), you'll see that it's
quite a bit more involved than the simple search-and-count that we've done; this is because NumPy's algorithm is more
flexible, and in particular is designed for better performance when the number of data points becomes large:
我们自己写的一行代码比NumPy优化的算法快了大约一倍,这是为什么?如果你深入到 np.histogram 函数的源代码进行阅读(你可以
通过在IPython中输⼊ np.histogram?? 来查阅)的时候,你会发现函数除了搜索和计数之外,还做了其他很多⼯作;这是因为NumPy
的函数要更加灵活,⽽且当数据量变⼤的时候能够提供更好的性能:
In [27]: x = np.random.randn(1000000)
print("NumPy routine:")
%timeit counts, edges = np.histogram(x, bins)
print("Custom routine:")
%timeit np.add.at(counts, np.searchsorted(bins, x), 1)
NumPy routine:
67.2 ms ± 1.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Custom routine:
90.4 ms ± 1.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
What this comparison shows is that algorithmic efficiency is almost never a simple question. An algorithm efficient for
large datasets will not always be the best choice for small datasets, and vice versa (see Big-O Notation). But the
advantage of coding this algorithm yourself is that with an understanding of these basic methods, you could use these
building blocks to extend this to do some very interesting custom behaviors. The key to efficiently using Python in data-intensive applications is knowing about general convenience routines like np.histogram and when they're
appropriate, but also knowing how to make use of lower-level functionality when you need more pointed behavior.
上面的结果说明,算法的性能几乎从来都不是一个简单的问题。对于大数据集来说一个很高效的算法,并不一定也适用于小数据
集,反之亦然(参⻅⼤O复杂度)。我们这⾥使⽤⾃⼰的代码实现这个算法,⽬的是理解上⾯的基本函数,后续读者可以使⽤这些函数构
建⾃⼰定义的各种功能。在数据科学应⽤中使⽤Python编写代码的关键在于,你能掌握NumPy提供的很⽅便的函数如 np.histogram ,
你也能知道什么情况下适合使⽤它们,当需要更加定制的功能时你还能使⽤底层的函数⾃⼰实现相应的算法。
Sorting Arrays
数组排序
Up to this point we have been concerned mainly with tools to access and operate on array data with NumPy. This section
covers algorithms related to sorting values in NumPy arrays. These algorithms are a favorite topic in introductory
computer science courses: if you've ever taken one, you probably have had dreams (or, depending on your
temperament, nightmares) about insertion sorts, selection sorts, merge sorts, quick sorts, bubble sorts, and many, many
more. All are means of accomplishing a similar task: sorting the values in a list or array.
本节之前,我们主要关注NumPy中那些获取和操作数组数据的⼯具。本⼩节我们会介绍对NumPy数组进⾏排序的算法。这些算法在基础计
算机科学领域是很热门的课题:如果你学习过相关的课程的话,你可能梦(或者根据你的性情,可能是噩梦)到过有关插入排序、选择排序、归并排序、快速排序、冒泡排序和其他很多很多名词。这些算法都是为了完成同一件工作:对数组进行排序。
For example, a simple selection sort repeatedly finds the minimum value from a list, and makes swaps until the list is
sorted. We can code this in just a few lines of Python:
例如,⼀个简单的选择排序会重复寻找列表中最⼩的值,然后和当前值进⾏交换,直到列表排序完成。我们可以在Python中⽤简单的⼏⾏
代码完成这个算法:
In [1]: import numpy as np
def selection_sort(x):
    for i in range(len(x)):
        swap = i + np.argmin(x[i:])  # index of the minimum in the remaining subarray
        (x[i], x[swap]) = (x[swap], x[i])  # swap the current value with that minimum
    return x
In [2]: x = np.array([2, 1, 4, 3, 5])
selection_sort(x)
Out[2]: array([1, 2, 3, 4, 5])
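A quick way to gain confidence in a hand-written sort is to compare it against np.sort on random inputs. The sketch below repeats the selection_sort definition from In [1] so that it runs standalone. The function sorts its argument in place, so each call receives a copy.

```python
import numpy as np

def selection_sort(x):
    for i in range(len(x)):
        swap = i + np.argmin(x[i:])          # position of the smallest remaining value
        (x[i], x[swap]) = (x[swap], x[i])    # move it to the front of the unsorted part
    return x

rng = np.random.default_rng(1)
for n in (0, 1, 5, 50):                      # include the empty and one-element edge cases
    data = rng.integers(0, 100, size=n)
    assert np.array_equal(selection_sort(data.copy()), np.sort(data))
print("selection_sort agrees with np.sort")
```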
As any first-year computer science major will tell you, the selection sort is useful for its simplicity, but is much too slow to
be useful for larger arrays. For a list of $N$ values, it requires $N$ loops, each of which does on the order of $\sim N$ comparisons to
find the swap value. In terms of the "big-O" notation often used to characterize these algorithms (see Big-O Notation),
selection sort averages $O[N^2]$: if you double the number of items in the list, the execution time will go up by about a
factor of four.
任何一个计算机科学专业的一年级学生都会告诉你,选择排序很简单,但是对于大的数组来说运行效率就不够了。对于具有 $N$ 个值的列表,它需要 $N$ 次循环,每次循环中需要约 $\sim N$ 次比较来寻找要交换的元素。大O表示法经常用来对算法性能进行定量分析(参见大O复杂度),选择排序平均需要 $O[N^2]$:如果列表中的元素个数加倍,执行时间增长大约是原来的4倍。
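You can make the quadratic scaling concrete by counting comparisons instead of timing. The sketch below is a plain-Python selection sort with a comparison counter. The counter is an illustration added here and is not part of the code above. For N items it performs N(N-1)/2 comparisons regardless of input order, so doubling N roughly quadruples the work.

```python
def selection_sort_comparisons(values):
    """Selection-sort a copy of `values`; return (sorted list, number of comparisons)."""
    x = list(values)
    comparisons = 0
    for i in range(len(x)):
        smallest = i
        for j in range(i + 1, len(x)):
            comparisons += 1              # one comparison per inner-loop step
            if x[j] < x[smallest]:
                smallest = j
        x[i], x[smallest] = x[smallest], x[i]
    return x, comparisons

for n in (500, 1000):
    _, count = selection_sort_comparisons(range(n, 0, -1))
    print(n, count)
# 500 124750
# 1000 499500
```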
Even selection sort, though, is much better than my all-time favorite sorting algorithm, the bogosort:
甚⾄选择排序也远⽐下⾯这个bogo排序算法有效地多,这是作者最喜爱的排序算法:
In [3]: def bogosort(x):
    while np.any(x[:-1] > x[1:]):
        np.random.shuffle(x)
    return x
In [4]: x = np.array([2, 1, 4, 3, 5])
bogosort(x)
Out[4]: array([1, 2, 3, 4, 5])
This silly sorting method relies on pure chance: it repeatedly applies a random shuffling of the array until the result
happens to be sorted. With an average scaling of $O[N \times N!]$ (that's $N$ times $N$ factorial), this should, quite obviously,
never be used for any real computation.
这个有趣而粗笨的算法完全依赖于概率:它重复地对数组进行随机乱序,直到结果刚好是正确排序为止。这个算法平均需要 $O[N \times N!]$,即N乘以N的阶乘,显然在真实情况下,它不应该被用于排序计算。
Fortunately, Python contains built-in sorting algorithms that are much more efficient than either of the simplistic algorithms
just shown. We'll start by looking at the Python built-ins, and then take a look at the routines included in NumPy and
optimized for NumPy arrays.
幸运的是,Python內建有了排序算法,⽐我们刚才提到那些简单的算法都要⾼效。我们从Python內建的排序开始介绍,然后再去讨论
NumPy中为了数组优化的排序函数。
Fast Sorting in NumPy: np.sort and np.argsort
NumPy中快速排序: np.sort 和 np.argsort
Although Python has built-in sort and sorted functions to work with lists, we won't discuss them here because
NumPy's np.sort function turns out to be much more efficient and useful for our purposes. By default np.sort uses
an $O[N\log N]$ quicksort algorithm, though mergesort and heapsort are also available. For most applications, the
default quicksort is more than sufficient.
虽然Python有内建的 sort 和 sorted 函数可以用来对列表进行排序,我们在这里不讨论它们。因为NumPy的 np.sort 函数有着更加优秀的性能,而且也更满足我们要求。默认情况下 np.sort 使用的是 $O[N\log N]$ 的快速排序算法,归并排序和堆排序也是可选的。对于大多数的应用场景来说,默认的快速排序都能满足要求。
To return a sorted version of the array without modifying the input, you can use np.sort :
对数组进⾏排序,返回排序后的结果,不改变原始数组的数据,你应该使⽤ np.sort :
In [5]: x = np.array([2, 1, 4, 3, 5])
np.sort(x)
Out[5]: array([1, 2, 3, 4, 5])
If you prefer to sort the array in-place, you can instead use the sort method of arrays:
如果你期望直接改变数组的数据进⾏排序,你可以对数组对象使⽤它的 sort ⽅法:
In [6]: x.sort()
print(x)
[1 2 3 4 5]
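The short sketch below contrasts the two forms. np.sort returns a new array and leaves its input alone, while the sort method works in place and returns None. The kind argument selects the algorithm, and kind='stable' asks for a stable sort that keeps equal elements in their original order. Stability matters when the result of argsort is used to reorder other data.

```python
import numpy as np

x = np.array([2, 1, 4, 3, 5])
y = np.sort(x)                # returns a sorted copy
print(x, y)                   # [2 1 4 3 5] [1 2 3 4 5]

result = x.sort()             # sorts x in place...
print(result, x)              # None [1 2 3 4 5]

# kind= chooses the algorithm; 'stable' keeps ties in their original order
keys = np.array([3, 1, 3, 1])
print(np.argsort(keys, kind='stable'))   # [1 3 0 2]
```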
A related function is argsort , which instead returns the indices of the sorted elements:
相关的函数是 argsort ,它将返回排好序后元素原始的序号序列:
In [7]: x = np.array([2, 1, 4, 3, 5])
i = np.argsort(x)
print(i)
[1 0 3 2 4]
The first element of this result gives the index of the smallest element, the second value gives the index of the second
smallest, and so on. These indices can then be used (via fancy indexing) to construct the sorted array if desired:
结果的第⼀个元素是数组中最⼩元素的序号,第⼆个元素是数组中第⼆⼩元素的序号,以此类推。这些序号可以通过⾼级索引的⽅式使
⽤,从⽽获得⼀个排好序的数组:
Translator's note: a more useful question is how to get the second- and third-smallest elements of the array; we can do it like this:
x[i[1:3]]
In [8]: x[i]
Out[8]: array([1, 2, 3, 4, 5])
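The sketch below shows two common follow-ups with the same x. Fancy indexing with the argsort result reproduces np.sort(x). Reversing or slicing the index array then gives a descending order or the smallest few elements.

```python
import numpy as np

x = np.array([2, 1, 4, 3, 5])
i = np.argsort(x)

print(np.array_equal(x[i], np.sort(x)))   # True
print(x[i[::-1]])                         # [5 4 3 2 1]  descending order
print(x[i[:2]])                           # [1 2]        the two smallest values
```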
Sorting along rows or columns
按照⾏或列进⾏排序
A useful feature of NumPy's sorting algorithms is the ability to sort along specific rows or columns of a multidimensional
array using the axis argument. For example:
NumPy的排序算法可以沿着多维数组的某些轴 axis 进行,如行或者列。例如:
In [9]: rand = np.random.RandomState(42)
X = rand.randint(0, 10, (4, 6))
print(X)
[[6 3 7 4 6 9]
[2 6 7 4 3 7]
[7 2 5 4 1 7]
[5 1 4 0 9 5]]
In [10]: # sort each column of X
np.sort(X, axis=0)
Out[10]: array([[2, 1, 4, 0, 1, 5],
[5, 2, 5, 4, 3, 7],
[6, 3, 7, 4, 6, 7],
[7, 6, 7, 4, 9, 9]])
In [11]: # sort each row of X
np.sort(X, axis=1)
Out[11]: array([[3, 4, 6, 6, 7, 9],
[2, 3, 4, 6, 7, 7],
[1, 2, 4, 5, 7, 7],
[0, 1, 4, 5, 5, 9]])
Keep in mind that this treats each row or column as an independent array, and any relationships between the row or
column values will be lost!
必须注意的是,这样的排序会独⽴的对每⼀⾏或者每⼀列进⾏排序。因此结果中原来⾏或列之间的联系都会丢失。
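When the rows must stay intact, a common approach is to argsort a single column and use the result to reorder whole rows with fancy indexing. The sketch below re-creates the same X from the same seed and sorts its rows by the first column.

```python
import numpy as np

rand = np.random.RandomState(42)
X = rand.randint(0, 10, (4, 6))

order = np.argsort(X[:, 0])     # row order that sorts the first column
X_by_first = X[order]           # reorder complete rows; each row stays intact

print(order)                    # [1 3 0 2]
print(X_by_first)
```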
Partial Sorts: Partitioning
部分排序:分区
Sometimes we're not interested in sorting the entire array, but simply want to find the K smallest values in the array.
NumPy provides this in the np.partition function. np.partition takes an array and a number K; the result is a
new array with the smallest K values to the left of the partition, and the remaining values to the right, in arbitrary order:
有时候我们并不是需要对整个数组排序,⽽仅仅需要找到数组中的K个最⼩值。NumPy提供了 np.partition 函数来完成这个任务;结
果会分为两部分,最⼩的K个值位于结果数组的左边,⽽其余的值位于数组的右边,顺序随机:
In [12]: x = np.array([7, 2, 3, 1, 6, 5, 4])
np.partition(x, 3)
Out[12]: array([2, 1, 3, 4, 6, 5, 7])
Note that the first three values in the resulting array are the three smallest in the array, and the remaining array positions
contain the remaining values. Within the two partitions, the elements have arbitrary order.
你可以看到结果中最⼩的三个值在左边,其余4个值位于数组的右边,每个分区内部,元素的顺序是任意的。
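The guarantee can be stated more precisely. In np.partition(x, K), the element at index K is the one that would sit there after a full sort. Everything before it is no larger, and everything after it is no smaller. The check below uses the same x.

```python
import numpy as np

x = np.array([7, 2, 3, 1, 6, 5, 4])
K = 3
p = np.partition(x, K)

print(p[K] == np.sort(x)[K])                            # True: index K is in sorted position
print(np.all(p[:K] <= p[K]), np.all(p[K + 1:] >= p[K])) # True True
print(np.sort(p[:K]))                                   # [1 2 3]: the K smallest, once sorted
```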
Similarly to sorting, we can partition along an arbitrary axis of a multidimensional array:
和排序⼀样,我们可以按照任意维度对⼀个多维数组进⾏分区:
In [13]: np.partition(X, 2, axis=1)
Out[13]: array([[3, 4, 6, 7, 6, 9],
[2, 3, 4, 7, 6, 7],
[1, 2, 4, 5, 7, 7],
[0, 1, 4, 5, 9, 5]])
The result is an array where the first two slots in each row contain the smallest values from that row, with the remaining
values filling the remaining slots.
结果中每⾏的前两个元素就是该⾏最⼩的两个值,该⾏其余的值会出现在后⾯。
Finally, just as there is a np.argsort that computes indices of the sort, there is a np.argpartition that computes
indices of the partition. We'll see this in action in the following section.
最后,就像 np.argsort 函数可以返回排好序的元素序号⼀样, np.argpartition 可以计算分区后元素的序号。后⾯的例⼦中我们会
看到它的使⽤。
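np.argpartition returns positions rather than values. Fancy indexing therefore recovers the K smallest values, in arbitrary order, together with where they sat in the original array. In this particular x, the positions and the values both happen to be {1, 2, 3}.

```python
import numpy as np

x = np.array([7, 2, 3, 1, 6, 5, 4])
K = 3
idx = np.argpartition(x, K)[:K]      # positions of the K smallest values (any order)

print(np.sort(idx))                  # [1 2 3]  positions in x
print(np.sort(x[idx]))               # [1 2 3]  the values at those positions
```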
Example: k-Nearest Neighbors
例⼦:k近邻
Let's quickly see how we might use this argsort function along multiple axes to find the nearest neighbors of each
point in a set. We'll start by creating a random set of 10 points on a two-dimensional plane. Using the standard
convention, we'll arrange these in a 10 × 2 array:
下面我们使用 argsort 沿着多个维度来寻找每个点的最近邻。首先在一个二维平面上创建10个随机点数据。按照惯例,这将是一个 $10 \times 2$ 的数组:
In [14]: X = rand.rand(10, 2)
To get an idea of how these points look, let's quickly scatter plot them:
我们先来观察⼀下这些点的分布情况,散点图很适合这种情形:
In [15]: %matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set()  # set the plot style
plt.scatter(X[:, 0], X[:, 1], s=100);
Now we'll compute the distance between each pair of points. Recall that the squared-distance between two points is the
sum of the squared differences in each dimension; using the efficient broadcasting (Computation on Arrays:
Broadcasting) and aggregation (Aggregations: Min, Max, and Everything In Between) routines provided by NumPy we
can compute the matrix of square distances in a single line of code:
现在让我们来计算每两个点之间的距离。距离平⽅的定义是两点坐标差的平⽅和。应⽤⼴播(在数组上计算:⼴播)和聚合(聚合:Min,
Max, 以及其他)函数,我们可以使⽤⼀⾏代码就能计算出所有点之间的距离平⽅:
In [16]: dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)
This operation has a lot packed into it, and it might be a bit confusing if you're unfamiliar with NumPy's broadcasting
rules. When you come across code like this, it can be useful to break it down into its component steps:
上⾯的这⾏代码包含很多的内容值得探讨,如果对于不是特别熟悉⼴播机制的读者来说,看起来可能会让⼈难以理解。当你读到这样的代
码的时候,将它们打散成⼀步步的操作会有帮助:
In [17]: # for each pair of points, compute differences in their coordinates
differences = X[:, np.newaxis, :] - X[np.newaxis, :, :]
differences.shape
Out[17]: (10, 10, 2)
In [18]: # square the coordinate differences
sq_differences = differences ** 2
sq_differences.shape
Out[18]: (10, 10, 2)
In [19]: # sum the coordinate differences to get the squared distance
dist_sq = sq_differences.sum(-1)
dist_sq.shape
Out[19]: (10, 10)
Just to double-check what we are doing, we should see that the diagonal of this matrix (i.e., the set of distances between
each point and itself) is all zero:
你可以检查这个矩阵的对⻆线元素,对⻆线元素的值是点与其⾃⾝的距离平⽅,应该全部为0:
In [20]: dist_sq.diagonal()
Out[20]: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
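The broadcasting version builds an intermediate array of shape (N, N, D). A commonly used alternative expands the square as |a - b|^2 = |a|^2 + |b|^2 - 2 a·b, which needs only (N, N) intermediates. This is a different technique from the one in the text and is shown only as a sketch on freshly generated points. Floating-point rounding can make it very slightly inexact, so the comparison uses np.allclose.

```python
import numpy as np

rand = np.random.RandomState(42)
X = rand.rand(10, 2)                      # freshly generated points for this sketch

# Broadcasting version from the text: (10, 10, 2) intermediate
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)

# Expanded-square version: |a|^2 + |b|^2 - 2 a.b
sq_norms = np.sum(X ** 2, axis=1)
dist_sq_alt = sq_norms[:, np.newaxis] + sq_norms[np.newaxis, :] - 2 * X @ X.T

print(np.allclose(dist_sq, dist_sq_alt))  # True
print(np.allclose(dist_sq, dist_sq.T))    # True: the matrix is symmetric
```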
It checks out! With the pairwise square-distances computed, we can now use np.argsort to sort along each row. The
leftmost columns will then give the indices of the nearest neighbors:
确认正确。现在我们已经有了⼀个距离平⽅的矩阵,然后就可以使⽤ np.argsort 函数来按照每⾏来排序。最左边的列就会给出每个点
的最近邻:
In [21]: nearest = np.argsort(dist_sq, axis=1)
print(nearest)
[[0 3 9 7 1 4 2 5 6 8]
[1 4 7 9 3 6 8 5 0 2]
[2 1 4 6 3 0 8 9 7 5]
[3 9 7 0 1 4 5 8 6 2]
[4 1 8 5 6 7 9 3 0 2]
[5 8 6 4 1 7 9 3 2 0]
[6 8 5 4 1 7 9 3 2 0]
[7 9 3 1 4 0 5 8 6 2]
[8 5 6 4 1 7 9 3 2 0]
[9 7 3 0 1 4 5 8 6 2]]
Notice that the first column gives the numbers 0 through 9 in order: this is due to the fact that each point's closest
neighbor is itself, as we would expect.
结果中的第⼀列是0到9的数字:这是因为距离每个点最近的是⾃⼰,正如我们预料的⼀样。
By using a full sort here, we've actually done more work than we need to in this case. If we're simply interested in the
nearest k neighbors, all we need is to partition each row so that the smallest k + 1 squared distances come first, with
larger distances filling the remaining positions of the array. We can do this with the np.argpartition function:
上面我们进行了完整的排序,事实上我们并不需要这么做。如果我们只是对最近的 $K$ 个邻居感兴趣的话,只需要对距离平方矩阵的每行进行分区,使最小的 $K+1$ 个距离平方排在前面,调用 np.argpartition 函数即可:
In [22]: K = 2
nearest_partition = np.argpartition(dist_sq, K + 1, axis=1)
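As a sanity check, the K + 1 indices that np.argpartition puts first in each row should form the same set as the first K + 1 columns of the full argsort. Their order may differ. The sketch below runs on freshly generated points, since the exact X above depends on earlier random draws.

```python
import numpy as np

rand = np.random.RandomState(0)
X = rand.rand(10, 2)                      # freshly generated points for this sketch
dist_sq = np.sum((X[:, np.newaxis, :] - X[np.newaxis, :, :]) ** 2, axis=-1)

K = 2
nearest = np.argsort(dist_sq, axis=1)
nearest_partition = np.argpartition(dist_sq, K + 1, axis=1)

for i in range(X.shape[0]):
    # same K + 1 closest points (the point itself plus its K neighbors), any order
    assert set(nearest[i, :K + 1]) == set(nearest_partition[i, :K + 1])
print("argpartition and argsort agree on the K + 1 closest points")
```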
In order to visualize this network of neighbors, let's quickly plot the points along with lines representing the connections
from each point to its two nearest neighbors:
为了展⽰最近邻的⽹络结构,我们在图中为每个点和它最近的两个点之间连上线:
In [27]: plt.scatter(X[:, 0], X[:, 1], s=100)
# draw lines from each point to its two nearest neighbors
K = 2
for i in range(X.shape[0]):
    for j in nearest_partition[i, :K+1]:
        # plot a line from X[i] to X[j],
        # using some zip magic to make it happen
        plt.plot(*zip(X[j], X[i]), color='black')
Each point in the plot has lines drawn to its two nearest neighbors. At first glance, it might seem strange that some of the
points have more than two lines coming out of them: this is due to the fact that if point A is one of the two nearest
neighbors of point B, this does not necessarily imply that point B is one of the two nearest neighbors of point A.
图上的每个点都和与它最近的两个点相连。初看起来,你可能注意到有些点的连线可能超过2条,这很奇怪:实际原因是如果A是B的最近
两个近邻之⼀,并不代表着B也必须是A的最近两个近邻之⼀。
Although the broadcasting and row-wise sorting of this approach might seem less straightforward than writing a loop, it
turns out to be a very efficient way of operating on this data in Python. You might be tempted to do the same type of
operation by manually looping through the data and sorting each set of neighbors individually, but this would almost
certainly lead to a slower algorithm than the vectorized version we used. The beauty of this approach is that it's written in
a way that's agnostic to the size of the input data: we could just as easily compute the neighbors among 100 or 1,000,000
points in any number of dimensions, and the code would look the same.
虽然使用广播和逐行排序的方式完成任务可能没有使用循环来得直观,但是在Python中这是一种非常有效的方式。你可能忍不住使用循环
的⽅式对每个点去计算它相应的最近邻,但是这种⽅式⼏乎肯定会⽐我们前⾯使⽤的向量化⽅案要慢很多。向量化的解法还有⼀个优点,
那就是它不关⼼数据的尺⼨:我们可以使⽤同样的代码和⽅法计算100个点或1,000,000个点以及任意维度数的数据的最近邻。
Finally, I'll note that when doing very large nearest neighbor searches, there are tree-based and/or approximate
algorithms that can scale as $O[N\log N]$ or better rather than the $O[N^2]$ of the brute-force algorithm. One example of
this is the KD-Tree, implemented in Scikit-learn.
最后,需要说明的是,当对一个非常大的数据集进行最近邻搜索时,还有一些基于树和/或近似方法的算法,能够将时间复杂度从暴力算法的 $O[N^2]$ 优化到 $O[N\log N]$ 或更好。其中一个例子是Scikit-learn中实现的KD-Tree。
Aside: Big-O Notation
额外内容:⼤ O 复杂度
Big-O notation is a means of describing how the number of operations required for an algorithm scales as the input grows
in size. To use it correctly is to dive deeply into the realm of computer science theory, and to carefully distinguish it from
the related small-o notation, big-θ notation, big-Ω notation, and probably many mutant hybrids thereof. While these distinctions
add precision to statements about algorithmic scaling, outside computer science theory exams and the remarks of
pedantic blog commenters, you'll rarely see such distinctions made in practice. Far more common in the data science
world is a less rigid use of big-O notation: as a general (if imprecise) description of the scaling of an algorithm. With
apologies to theorists and pedants, this is the interpretation we'll use throughout this book.
大O复杂度是一种描述随着输入数据规模增加,算法所需操作数量如何增长的方法。要正确使用它,需要深入了解计算机科学的理论知识,并仔细地将它与其他相关的概念如小o复杂度、大θ复杂度、大Ω复杂度以及它们的各种变体区分开来。虽然这些区分能让关于算法扩展性的描述更加精确,但除了计算机科学理论考试和较真的博客评论者以外,你在实际应用中很少会看到这样的区分。在数据科学领域中更常见的是一种不那么严格的大O复杂度用法:把它作为对算法扩展性的一般性(虽然不精确)描述。带着对理论学者和较真者的歉意,本书将一直使用这种对大O复杂度的非精确解释。
Big-O notation, in this loose sense, tells you how much time your algorithm will take as you increase the amount of data.
If you have an $O[N]$ (read "order $N$") algorithm that takes 1 second to operate on a list of length N=1,000, then you
should expect it to take roughly 5 seconds for a list of length N=5,000. If you have an $O[N^2]$ (read "order N squared")
algorithm that takes 1 second for N=1000, then you should expect it to take about 25 seconds for N=5000.
大O复杂度,简单来说,会告诉你当你的数据增大时,你的算法运行需要的时间。例如你有一个 $O[N]$(英文读作"Order $N$")的算法,对于N=1000的数据量,它需要运行1秒,那么对于N=5000的数据量,算法需要执行的时间大约为5秒。如果你的算法复杂度为 $O[N^2]$(英文读作"Order N squared"),对于N=1000的数据量需要运行1秒,那么你可以预期当数据量增长为N=5000时,运行时间大约为25秒。
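Here is the arithmetic behind these estimates. Suppose an algorithm scales as $O[N^p]$ and takes t1 seconds at size N1. The loose big-O reading then predicts roughly t1 × (N2 / N1)^p seconds at size N2. The tiny helper below is hypothetical, written only for illustration, and reproduces both numbers above.

```python
def predicted_time(t1, n1, n2, power):
    """Rough big-O extrapolation, assuming run time is proportional to N**power."""
    return t1 * (n2 / n1) ** power

print(predicted_time(1.0, 1000, 5000, 1))   # 5.0   for O[N]
print(predicted_time(1.0, 1000, 5000, 2))   # 25.0  for O[N^2]
```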
For our purposes, the N will usually indicate some aspect of the size of the dataset (the number of points, the number of
dimensions, etc.). When trying to analyze billions or trillions of samples, the difference between $O[N]$ and $O[N^2]$ can
be far from trivial!
对于我们的目标来说,N通常代表着数据集的大小(数据点的数量,维度数等)。当我们需要分析的数据样本量达到十亿级或万亿级时,$O[N]$ 和 $O[N^2]$ 之间的差距将会是巨大的。
Notice that the big-O notation by itself tells you nothing about the actual wall-clock time of a computation, but only about
its scaling as you change N. Generally, for example, an $O[N]$ algorithm is considered to have better scaling than an
$O[N^2]$ algorithm, and for good reason. But for small datasets in particular, the algorithm with better scaling might not be
faster. For example, in a given problem an $O[N^2]$ algorithm might take 0.01 seconds, while a "better" $O[N]$ algorithm
might take 1 second. Scale up N by a factor of 1,000, though, and the $O[N]$ algorithm will win out.
请记住大O复杂度本身并不能告诉你实际上运算消耗的时间,它仅仅能够告诉你当N变化时,运行时间会怎样随之发生变化。通常来说,$O[N]$ 复杂度的算法被认为比 $O[N^2]$ 复杂度的算法具有更好的扩展性,这是有充分理由的。但对于小的数据集来说,扩展性更好的算法并不一定执行得更快。例如,某个特定情况下,$O[N^2]$ 复杂度的算法可能需要0.01秒的运行时间,而"更好的" $O[N]$ 复杂度的算法可能需要1秒。但是如果将N增大1000倍,那么 $O[N]$ 复杂度的算法将会胜出。
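Worked out explicitly for that example, scaling $N$ by a factor of 1,000 multiplies the $O[N^2]$ time by $1000^2$ and the $O[N]$ time by $1000$:

$$
\underbrace{0.01\,\mathrm{s}\times 1000^{2}}_{O[N^2]} = 10^{4}\,\mathrm{s}
\qquad\text{versus}\qquad
\underbrace{1\,\mathrm{s}\times 1000}_{O[N]} = 10^{3}\,\mathrm{s}
$$

After the scale-up, the $O[N]$ algorithm is about ten times faster, even though it started out 100 times slower.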
Even this loose version of Big-O notation can be very useful when comparing the performance of algorithms, and we'll
use this notation throughout the book when talking about how algorithms scale.
我们这⾥使⽤的这种⾮严格定义的⼤O复杂度对于算法的性能也是有指⽰意义的,在本书的后续部分当我们讨论到算法范畴时都会应⽤到
它。
Structured Data: NumPy's Structured Arrays
结构化数据:NumPy里的结构化数组
While often our data can be well represented by a homogeneous array of values, sometimes this is not the case. This
section demonstrates the use of NumPy's structured arrays and record arrays, which provide efficient storage for
compound, heterogeneous data. While the patterns shown here are useful for simple operations, scenarios like this often
lend themselves to the use of Pandas DataFrames, which we'll explore in Chapter 3.
虽然我们的数据很多情况下都能表示成同一种类型的数组,但是某些情况下并非如此。本小节展示了如何使用NumPy的结构化数组和记录数组,它们能够为复合的、不同类型的数据提供有效的存储方式。本小节展示的模式对于简单的操作很有用,但这类场景通常更适合使用Pandas的 DataFrame,我们会在第三章中详细讨论。
In [1]: import numpy as np
Imagine that we have several categories of data on a number of people (say, name, age, and weight), and we'd like to
store these values for use in a Python program. It would be possible to store these in three separate arrays:
考虑一下,我们有一些关于人的不同种类的数据(例如姓名、年龄和体重),现在我们想要将它们保存到Python程序中。当然它们可以被保存到三个独立的数组之中:
In [2]: name = ['Alice', 'Bob', 'Cathy', 'Doug']
age = [25, 45, 37, 19]
weight = [55.0, 85.5, 68.0, 61.5]
But this is a bit clumsy. There's nothing here that tells us that the three arrays are related; it would be more natural if we
could use a single structure to store all of this data. NumPy can handle this through structured arrays, which are arrays
with compound data types.
显然这种做法有些笨拙。没有任何信息告诉我们这三个数组是关联的;如果我们可以使用一个结构保存所有这些数据的话,会更加自然。NumPy使用结构化数组来处理这种情况,结构化数组是具有复合数据类型的数组。
Recall that previously we created a simple array using an expression like this:
回忆前面我们创建一个简单数组的方法:
In [3]: x = np.zeros(4, dtype=int)
We can similarly create a structured array using a compound data type specification:
我们也可以类似地创建一个结构化数组,只需要指定复合的dtype数据类型即可:
In [4]: # Use a compound data type for structured arrays
data = np.zeros(4, dtype={'names':('name', 'age', 'weight'),
'formats':('U10', 'i4', 'f8')})
print(data.dtype)
[('name', '<U10'), ('age', '<i4'), ('weight', '<f8')]
Here 'U10' translates to "Unicode string of maximum length 10," 'i4' translates to "4-byte (i.e., 32 bit) integer," and
'f8' translates to "8-byte (i.e., 64 bit) float." We'll discuss other options for these type codes in the following section.
这里的 U10 代表着“Unicode编码的字符串,最大长度10”, i4 代表着“4字节(32比特)整数”, f8 代表着“8字节(64比特)浮点数”。本节后面我们会介绍其他的类型选项。
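Because every field has a fixed size, the byte width of each record follows directly from these type codes. The sketch below checks this with the `dtype.itemsize` attribute. It assumes the default packed layout (no `align=True`), in which the fields sit back to back: 10 Unicode characters at 4 bytes each, plus 4 bytes, plus 8 bytes.

```python
import numpy as np

dt = np.dtype({'names': ('name', 'age', 'weight'),
               'formats': ('U10', 'i4', 'f8')})

# Each 'U' character is stored as 4 bytes, so 'U10' takes 40 bytes.
print(dt['name'].itemsize)    # 40
print(dt['age'].itemsize)     # 4
print(dt['weight'].itemsize)  # 8
# Packed layout: 40 + 4 + 8 = 52 bytes per record
print(dt.itemsize)            # 52
```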
Now that we've created an empty container array, we can fill the array with our lists of values:
现在我们已经创建了一个空的结构化数组,我们可以使用上面的数据列表将数据填充到数组中:
In [5]: data['name'] = name
data['age'] = age
data['weight'] = weight
print(data)
[('Alice', 25, 55. ) ('Bob', 45, 85.5) ('Cathy', 37, 68. )
('Doug', 19, 61.5)]
As we had hoped, the data is now arranged together in one convenient block of memory.
正如我们希望的那样,数组的数据现在被存储在一整块方便使用的内存空间中。
The handy thing with structured arrays is that you can now refer to values either by index or by name:
使用结构化数组的方便之处在于,你既可以使用序号,也可以使用字段名称来访问元素值:
In [6]: # Get all names
data['name']
Out[6]: array(['Alice', 'Bob', 'Cathy', 'Doug'], dtype='<U10')
In [7]: # Get first row of data
data[0]
Out[7]: ('Alice', 25, 55.)
In [8]: # Get the name from the last row
data[-1]['name']
Out[8]: 'Doug'
Using Boolean masking, this even allows you to do some more sophisticated operations such as filtering on age:
使用布尔遮盖,我们甚至可以完成一些更复杂的操作,比如按照年龄进行过滤:
In [9]: # Get names where age is under 30
data[data['age'] < 30]['name']
Out[9]: array(['Alice', 'Doug'], dtype='<U10')
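Fields can also drive other array operations. For example, NumPy's `np.sort` accepts an `order` argument that names the field(s) to sort by when it is given a structured array. A minimal sketch that rebuilds the same four-person array and sorts it by age:

```python
import numpy as np

data = np.zeros(4, dtype={'names': ('name', 'age', 'weight'),
                          'formats': ('U10', 'i4', 'f8')})
data['name'] = ['Alice', 'Bob', 'Cathy', 'Doug']
data['age'] = [25, 45, 37, 19]
data['weight'] = [55.0, 85.5, 68.0, 61.5]

# Sort whole records by the 'age' field; each row keeps its name and weight.
by_age = np.sort(data, order='age')
print(by_age['name'])  # ['Doug' 'Alice' 'Cathy' 'Bob']
```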
Note that if you'd like to do any operations that are any more complicated than these, you should probably consider the
Pandas package, covered in the next chapter. As we'll see, Pandas provides a DataFrame object, which is a structure
built on NumPy arrays that offers a variety of useful data manipulation functionality similar to what we've shown here, as
well as much, much more.
请注意,如果你想要完成的工作比上面这些还要复杂的话,你应该考虑使用Pandas包,这是下一章的主要内容。我们将会看到,Pandas提供了 DataFrame 对象,它是一个在NumPy数组的基础上构建的结构,提供了很多有用的数据操作功能,既包括与上面类似的功能,也包括多得多的其他功能。
Creating Structured Arrays
创建结构化数组
Structured array data types can be specified in a number of ways. Earlier, we saw the dictionary method:
结构化数组的数据类型可以采用几种方式指定。前面我们介绍了字典的方式:
In [10]: np.dtype({'names':('name', 'age', 'weight'),
'formats':('U10', 'i4', 'f8')})
Out[10]: dtype([('name', '<U10'), ('age', '<i4'), ('weight', '<f8')])
For clarity, numerical types can be specified using Python types or NumPy dtypes instead:
为了清晰起见,数值类型也可以使用Python类型或NumPy的dtype来指定:
In [11]: np.dtype({'names':('name', 'age', 'weight'),
'formats':((np.str_, 10), int, np.float32)})
Out[11]: dtype([('name', '<U10'), ('age', '<i8'), ('weight', '<f4')])
A compound type can also be specified as a list of tuples:
一个复合类型也可以使用一个元组的列表来指定:
In [12]: np.dtype([('name', 'S10'), ('age', 'i4'), ('weight', 'f8')])
Out[12]: dtype([('name', 'S10'), ('age', '<i4'), ('weight', '<f8')])
If the names of the fields do not matter to you, you can specify the types alone in a comma-separated string:
如果字段的名称对你来说并不重要,你可以只在一个以逗号分隔的字符串中指定所有类型(字段会被自动命名为 f0、f1、f2 等):
In [13]: np.dtype('S10,i4,f8')
Out[13]: dtype([('f0', 'S10'), ('f1', '<i4'), ('f2', '<f8')])
The shortened string format codes may seem confusing, but they are built on simple principles. The first (optional)
character is < or >, which means "little endian" or "big endian," respectively, and specifies the byte-ordering convention
for multi-byte values. The next character specifies the type of data: characters, bytes, ints, floating points, and so on (see the
table below). The last character or characters represent the size of the object in bytes (for string types such as 'S10'
and 'U10', the number is instead the maximum length in characters).
类型的字符串形式的缩写初看起来令人困惑,但实际上它们都遵循简单的规则。第一个(可选的)字符是 < 或 >,分别代表“小端”(little endian)或“大端”(big endian),用来指定多字节数值的字节序。下一个字符指定数据类型:字符、字节、整数、浮点数或其他(见下表)。最后的一个或多个字符代表对象的字节长度(对于 'S10'、'U10' 这样的字符串类型,数字代表的是最大字符数)。
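Character 字符 | Description 说明 | Example 举例
'b' | Byte 字节 | np.dtype('b')
'i' | Signed integer 带符号整数 | np.dtype('i4') == np.int32
'u' | Unsigned integer 无符号整数 | np.dtype('u1') == np.uint8
'f' | Floating point 浮点数 | np.dtype('f8') == np.float64
'c' | Complex floating point 复数 | np.dtype('c16') == np.complex128
'S', 'a' | String 字符串 | np.dtype('S5')
'U' | Unicode string Unicode字符串 | np.dtype('U') == np.str_
'V' | Raw data (void) 原始数据 | np.dtype('V') == np.void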
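You can check the equalities in the table directly, and the raw bytes show the effect of the byte-order prefix. In this sketch, `tobytes()` returns the array's memory in its stored byte order, so the same integer 1 is laid out differently under `'>i4'` and `'<i4'`:

```python
import numpy as np

# Type-code equalities from the table above
assert np.dtype('i4') == np.int32
assert np.dtype('u1') == np.uint8
assert np.dtype('f8') == np.float64
assert np.dtype('c16') == np.complex128

# The same value stored big-endian ('>') vs little-endian ('<')
big = np.array([1], dtype='>i4').tobytes()
little = np.array([1], dtype='<i4').tobytes()
print(big)     # b'\x00\x00\x00\x01'  (most significant byte first)
print(little)  # b'\x01\x00\x00\x00'  (least significant byte first)
```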
More Advanced Compound Types
高级复合类型
It is possible to define even more advanced compound types. For example, you can create a type where each element
contains an array or matrix of values. Here, we'll create a data type with a mat component consisting of a 3 × 3
floating-point matrix:
除此之外,还可以定义更加复杂的复合类型。例如,你可以创建一个类型,其中的每一个元素都包含一个数组或矩阵。下面,我们创建一个数据类型,其中包含一个 mat 字段,它是一个 3×3 的浮点数矩阵:
In [14]: tp = np.dtype([('id', 'i8'), ('mat', 'f8', (3, 3))])
X = np.zeros(1, dtype=tp)
print(X[0])
print(X['mat'][0])
(0, [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]])
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
Now each element in the X array consists of an id and a 3 × 3 matrix. Why would you use this rather than a simple
multidimensional array, or perhaps a Python dictionary? The reason is that this NumPy dtype directly maps onto a C
structure definition, so the buffer containing the array content can be accessed directly within an appropriately written C
program. If you find yourself writing a Python interface to a legacy C or Fortran library that manipulates structured data,
you'll probably find structured arrays quite useful!
现在 X 数组中的每个元素都由一个 id 和一个 3×3 的矩阵组成。为什么需要这样用,而不用一个简单的多维数组或者Python的字典呢?原因是NumPy的 dtype 数据类型直接对应着一个C语言的结构体定义,因此存储这个数组内容的缓冲区可以直接被适当编写的C语言程序访问。如果你在为一个操作结构化数据的遗留C语言或Fortran语言库编写Python接口,你会发现这种结构化数组很有用。
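To illustrate that mapping, the sketch below packs one record with Python's standard `struct` module and reads it back with `np.frombuffer`. The packed bytes stand in for data written by a C program. The sketch assumes a hypothetical C layout `struct { int64_t id; double mat[3][3]; }` with little-endian byte order and no padding between the members. The byte order is therefore spelled out explicitly in the dtype:

```python
import struct

import numpy as np

# Little-endian dtype matching the hypothetical C struct layout
tp = np.dtype([('id', '<i8'), ('mat', '<f8', (3, 3))])

# Simulate raw bytes produced by C code: one int64 followed by nine doubles
raw = struct.pack('<q9d', 42, *[float(v) for v in range(9)])
print(len(raw), tp.itemsize)  # 80 80

# Interpret the buffer directly as an array of records (no copying)
records = np.frombuffer(raw, dtype=tp)
print(records['id'][0])       # 42
print(records['mat'][0])      # 3x3 matrix, filled row by row
```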
RecordArrays: Structured Arrays with a Twist
记录数组:面向对象的结构化数组
NumPy also provides the np.recarray class, which is almost identical to the structured arrays just described, but with
one additional feature: fields can be accessed as attributes rather than as dictionary keys. Recall that we previously
accessed the ages by writing:
NumPy还提供了 np.recarray 类,它和前面介绍的结构化数组几乎完全相同,但是有一个额外的特性:字段可以作为属性访问,而不是作为字典的关键字访问。回忆一下,前面我们是这样访问年龄字段的:
In [15]: data['age']
Out[15]: array([25, 45, 37, 19], dtype=int32)
If we view our data as a record array instead, we can access this with slightly fewer keystrokes:
如果我们将数据视为(view)记录数组,就可以使用属性的方式访问年龄字段,少打几个字:
In [16]: data_rec = data.view(np.recarray)
data_rec.age
Out[16]: array([25, 45, 37, 19], dtype=int32)
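Note that `view` does not copy anything. `data_rec` and `data` share the same buffer, so a write through one is visible through the other. A minimal sketch, rebuilding a two-person version of the array so that it runs on its own:

```python
import numpy as np

data = np.zeros(2, dtype=[('name', 'U10'), ('age', 'i4')])
data['name'] = ['Alice', 'Bob']
data['age'] = [25, 45]

data_rec = data.view(np.recarray)
print(np.shares_memory(data, data_rec))  # True: same underlying buffer

# Writing through the attribute-style view changes the original array too
data_rec.age[0] = 26
print(data['age'])  # [26 45]
```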
The downside is that for record arrays, there is some extra overhead involved in accessing the fields, even when using
the same syntax. We can see this here:
这样做的缺点是,对于记录数组,即使使用相同的语法访问字段,也会有额外的性能开销。下面的例子可以看到:
In [17]: %timeit data['age']
%timeit data_rec['age']
%timeit data_rec.age
95.2 ns ± 3.1 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
2.42 µs ± 187 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
3.13 µs ± 206 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Whether the more convenient notation is worth the additional overhead will depend on your own application.
是否值得为了更方便的写法而承担额外的性能开销,取决于你自己的应用需求。
On to Pandas
进入Pandas
This section on structured and record arrays is purposely at the end of this chapter, because it leads so well into the next
package we will cover: Pandas. Structured arrays like the ones discussed here are good to know about for certain
situations, especially in case you're using NumPy arrays to map onto binary data formats in C, Fortran, or another
language. For day-to-day use of structured data, the Pandas package is a much better choice, and we'll dive into a full
discussion of it in the chapter that follows.
本小节介绍的结构化数组和记录数组被有意安排在本章的最后,因为它们能很好地引出下一个要介绍的包:Pandas。了解本节介绍的结构化数组在某些情况下是有用的,特别是当你使用NumPy数组来映射C、Fortran或其他语言的二进制数据格式时。但是对于日常的结构化数据应用来说,Pandas包是一个好得多的选择,我们在下一章会以一整章的篇幅来详细介绍它。
Data Manipulation with Pandas
使用Pandas处理数据
In the previous chapter, we dove into detail on NumPy and its ndarray object, which provides efficient storage and
manipulation of dense typed arrays in Python. Here we'll build on this knowledge by looking in detail at the data
structures provided by the Pandas library. Pandas is a newer package built on top of NumPy, and provides an efficient
implementation of a DataFrame. DataFrames are essentially multidimensional arrays with attached row and column
labels, and often with heterogeneous types and/or missing data. As well as offering a convenient storage interface for
labeled data, Pandas implements a number of powerful data operations familiar to users of both database frameworks
and spreadsheet programs.
在上一章中,我们深入介绍了NumPy和它的 ndarray 对象,它被用来在Python中高效地存储和操作稠密的类型化数组。以此为基础,本章将详细介绍Pandas库为我们提供的数据结构。Pandas是一个在NumPy的基础上创建的较新的包,它提供了 DataFrame 对象的高效实现。DataFrame 本质上是一个多维数组,其行和列都带有标签,而且通常含有不同类型的数据和/或缺失的数据。除了为带标签的数据提供方便的存储接口之外,Pandas还实现了许多强大的数据操作,数据库框架和电子表格程序的用户都会对这些操作感到熟悉。
As we saw, NumPy's ndarray data structure provides essential features for the type of clean, well-organized data
typically seen in numerical computing tasks. While it serves this purpose very well, its limitations become clear when we
need more flexibility (e.g., attaching labels to data, working with missing data, etc.) and when attempting operations that
do not map well to element-wise broadcasting (e.g., groupings, pivots, etc.), each of which is an important piece of
analyzing the less structured data available in many forms in the world around us. Pandas, and in particular its Series
and DataFrame objects, builds on the NumPy array structure and provides efficient access to these sorts of "data
munging" tasks that occupy much of a data scientist's time.
正如我们前面看到的,NumPy的 ndarray 数据结构为数值计算任务中常见的干净、组织良好的数据提供了必不可少的功能。虽然 ndarray 很好地满足了这个目的,但是当我们需要更多的灵活性时(例如,为数据提供标签,处理缺失的数据等),或者需要进行不能很好地映射到逐元素广播的操作时(例如分组、数据透视等),它的局限性就体现了出来。而这些能力对于分析我们周围真实世界中大量存在的、结构较松散的数据来说都非常重要。Pandas,特别是它的 Series 和 DataFrame 对象,构建在NumPy数组结构之上,为这类占用数据科学家大量时间的“数据整理”(data munging)任务提供了高效的支持。
In this chapter, we will focus on the mechanics of using Series, DataFrame, and related structures effectively. We
will use examples drawn from real datasets where appropriate, but these examples are not necessarily the focus.
我们在本章中会聚焦于有效使用 Series、DataFrame 和相关结构的机制上。在合适的地方,例子中会使用真实的数据集进行说明,但是这些例子本身并不一定是重点。
Installing and Using Pandas
安装和使用Pandas
Installation of Pandas on your system requires NumPy to be installed, and if building the library from source, requires the
appropriate tools to compile the C and Cython sources on which Pandas is built. Details on this installation can be found
in the Pandas documentation. If you followed the advice outlined in the Preface and used the Anaconda stack, you
already have Pandas installed.
在你的系统上安装Pandas必须先安装NumPy,如果选择从源码进行构建,还需要能够编译C和Cython源代码的工具,因为Pandas正是基于它们构建的。详细的安装说明可以参见Pandas在线文档。如果你是依照前言中的方法使用Anaconda安装的环境,那么Pandas已经安装好了。
Once Pandas is installed, you can import it and check the version:
安装后,你可以载入包并检查版本信息,验证安装是否成功:
In [1]: import pandas
pandas.__version__
Out[1]: '0.24.2'
Just as we generally import NumPy under the alias np, we will import Pandas under the alias pd:
就像我们通常将NumPy载入并命名为 np 一样,我们也按惯例将Pandas载入并命名为 pd:
In [2]: import pandas as pd
This import convention will be used throughout the remainder of this book.
这个惯例会贯穿本书后续所有内容。
Reminder about Built-In Documentation
内建帮助及文档的提醒
As you read through this chapter, don't forget that IPython gives you the ability to quickly explore the contents of a
package (by using the tab-completion feature) as well as the documentation of various functions (using the ?
character). (Refer back to Help and Documentation in IPython if you need a refresher on this.)
当你阅读本章的时候,不要忘记IPython提供了快速查看包内容(使用tab自动补全)和各种函数帮助文档(使用 ? 字符)的工具。(如果需要复习,参见IPython的帮助和文档。)
For example, to display all the contents of the pandas namespace, you can type
例如,要查看pandas命名空间中的所有内容,你可以输入
In [3]: pd.<TAB>
And to display Pandas's built-in documentation, you can use this:
要显示Pandas的内建文档,你可以输入
In [4]: pd?
More detailed documentation, along with tutorials and other resources, can be found at http://pandas.pydata.org/.
更详细的文档,包括教程和其他资源,可以访问http://pandas.pydata.org/。
Introducing Pandas Objects
Pandas对象简介
At the very basic level, Pandas objects can be thought of as enhanced versions of NumPy structured arrays in which the
rows and columns are identified with labels rather than simple integer indices. As we will see during the course of this
chapter, Pandas provides a host of useful tools, methods, and functionality on top of the basic data structures, but nearly
everything that follows will require an understanding of what these structures are. Thus, before we go any further, let's
introduce these three fundamental Pandas data structures: the Series, DataFrame, and Index.
在最基本的层面上,Pandas的对象可以被认为是NumPy结构化数组的一个增强版本,它的行和列都使用标签来标识,而不仅仅像NumPy那样只能使用简单的整数序号。随着本章的推进,你会学习到很多Pandas在基本数据结构之上提供的工具、方法和功能,但是几乎所有这些内容都需要首先理解这些数据结构。因此,在深入之前,让我们先来介绍三个最基本的Pandas数据结构: Series、DataFrame 和 Index。
We will start our code sessions with the standard NumPy and Pandas imports:
在写其他代码前,我们先将NumPy和Pandas按照标准方式载入:
In [1]: import numpy as np
import pandas as pd
The Pandas Series Object
Pandas的Series对象
A Pandas Series is a one-dimensional array of indexed data. It can be created from a list or array as follows:
Pandas的 Series 是一个带索引的一维数组。它可以通过列表或数组创建:
In [2]: data = pd.Series([0.25, 0.5, 0.75, 1.0])
data
Out[2]: 0    0.25
1    0.50
2    0.75
3    1.00
dtype: float64
As we see in the output, the Series wraps both a sequence of values and a sequence of indices, which we can access
with the values and index attributes. The values are simply a familiar NumPy array:
我们从结果看到, Series 封装了一个值的序列和一个索引序号的序列,我们可以分别通过 values 和 index 属性访问它们。 values 属性就是你已经很熟悉的NumPy数组:
In [3]: data.values
Out[3]: array([0.25, 0.5 , 0.75, 1.  ])
The index is an array-like object of type pd.Index, which we'll discuss in more detail momentarily.
index 是一个类似数组的对象,类型是 pd.Index,我们很快会详细介绍它。
In [4]: data.index
Out[4]: RangeIndex(start=0, stop=4, step=1)
Like with a NumPy array, data can be accessed by the associated index via the familiar Python square-bracket notation:
和NumPy数组一样,你可以通过熟悉的Python中括号语法,使用相应的索引来访问数据值:
In [5]: data[1]
Out[5]: 0.5
In [6]: data[1:3]
Out[6]: 1    0.50
2    0.75
dtype: float64
As we will see, though, the Pandas Series is much more general and flexible than the one-dimensional NumPy array
that it emulates.
不过你将会看到,Pandas的 Series 要比它所模仿的一维NumPy数组通用和灵活得多。
Series as generalized NumPy array
Series作为通用的NumPy数组
From what we've seen so far, it may look like the Series object is basically interchangeable with a one-dimensional
NumPy array. The essential difference is the presence of the index: while the NumPy array has an implicitly defined
integer index used to access the values, the Pandas Series has an explicitly defined index associated with the values.
目前为止,我们看到的 Series 对象和一维NumPy数组似乎基本上可以互换。两者最本质的区别在于索引:NumPy数组使用隐式定义的整数索引来访问值,而Pandas的 Series 则拥有与值相关联的显式定义的索引。
This explicit index definition gives the Series object additional capabilities. For example, the index need not be an
integer, but can consist of values of any desired type. For example, if we wish, we can use strings as an index:
显式定义的索引给 Series 对象带来了额外的能力。例如,索引值不一定是整数,可以由任何需要的数据类型组成。比方说,下面我们用字符串来作为索引:
In [7]: data = pd.Series([0.25, 0.5, 0.75, 1.0],
index=['a', 'b', 'c', 'd'])
data
Out[7]: a    0.25
b    0.50
c    0.75
d    1.00
dtype: float64
And the item access works as expected:
然后元素可以通过相应的索引值来访问:
In [8]: data['b']
Out[8]: 0.5
We can even use non-contiguous or non-sequential indices:
我们甚至可以使用非连续或非顺序的索引值:
In [9]: data = pd.Series([0.25, 0.5, 0.75, 1.0],
index=[2, 5, 3, 7])
data
Out[9]: 2    0.25
5    0.50
3    0.75
7    1.00
dtype: float64
In [10]: data[5]
Out[10]: 0.5
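Here `data[5]` is a label lookup, not a position: the label 5 happens to sit at position 1. The `Index.get_loc` method makes this label-to-position mapping explicit. A short sketch:

```python
import pandas as pd

data = pd.Series([0.25, 0.5, 0.75, 1.0], index=[2, 5, 3, 7])

pos = data.index.get_loc(5)   # position of the label 5 in the index
print(pos)                    # 1
print(data[5] == data.values[pos])  # True: label 5 refers to the 2nd value
```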
Series as specialized dictionary
Series作为特殊的字典
In this way, you can think of a Pandas Series a bit like a specialization of a Python dictionary. A dictionary is a
structure that maps arbitrary keys to a set of arbitrary values, and a Series is a structure which maps typed keys to a
set of typed values. This typing is important: just as the type-specific compiled code behind a NumPy array makes it more
efficient than a Python list for certain operations, the type information of a Pandas Series makes it much more efficient
than Python dictionaries for certain operations.
从这个角度来看,你可以将Pandas的 Series 当成Python字典的一种特殊形式。Python中的字典将任意的关键字key映射到任意的值value,而 Series 则将特定类型的关键字映射到特定类型的值。这种类型信息是很重要的:正如NumPy数组背后针对特定类型编译的代码使它在某些操作上比Python列表更高效一样,Pandas的 Series 的类型信息也使它在某些操作上比Python字典高效得多。
The Series-as-dictionary analogy can be made even more clear by constructing a Series object directly from a
Python dictionary:
直接用一个Python字典创建一个 Series,可以更清楚地理解 Series 作为字典的类比:
In [14]: population_dict = {'California': 38332521,
'Texas': 26448193,
'New York': 19651127,
'Florida': 19552860,
'Illinois': 12882135}
population = pd.Series(population_dict)
population
Out[14]: California    38332521
Texas         26448193
New York      19651127
Florida       19552860
Illinois      12882135
dtype: int64
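By default, a Series will be created where the index is drawn from the dictionary keys, kept in insertion order (pandas
versions before 0.23, or Python versions before 3.6, sorted the keys instead). From here, typical dictionary-style
item access can be performed:
默认情况下, Series 的索引取自字典的关键字,并保持它们的插入顺序(在0.23之前的pandas版本或3.6之前的Python版本中,关键字会被排序)。然后就可以使用Python标准的字典语法获取值: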
In [15]: population['California']
Out[15]: 38332521
Unlike a dictionary, though, the Series also supports array-style operations such as slicing:
不过和字典不同的是, Series 还支持数组风格的操作,比如切片:
In [16]: population['California':'Illinois']
Out[16]: California    38332521
Texas         26448193
New York      19651127
Florida       19552860
Illinois      12882135
dtype: int64
We'll discuss some of the quirks of Pandas indexing and slicing in Data Indexing and Selection.
我们会在数据索引和选择中讨论Pandas索引和切片操作的一些特殊之处。
Constructing Series objects
构建Series对象
We've already seen a few ways of constructing a Pandas Series from scratch; all of them are some version of the
following:
我们已经看到了几种从零开始构建Pandas的 Series 对象的方法;它们都是下面这个构造方法的某种形式:
>>> pd.Series(data, index=index)
where index is an optional argument, and data can be one of many entities.
其中的 index 是一个可选的参数,而 data 可以是多种类型的数据。
For example, data can be a list or NumPy array, in which case index defaults to an integer sequence:
例如, data 可以是一个列表或NumPy数组,在这种情况下 index 默认是一个整数序列:
In [17]: pd.Series([2, 4, 6])
Out[17]: 0    2
1    4
2    6
dtype: int64
data can be a scalar, which is repeated to fill the specified index:
data 可以是一个标量,这种情况下标量的值会被重复填充到指定的index中:
In [18]: pd.Series(5, index=[100, 200, 300])
Out[18]: 100    5
200    5
300    5
dtype: int64
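data can be a dictionary, in which case index defaults to the dictionary keys (in insertion order for pandas 0.23+ on Python 3.6+; sorted in older versions):
data 可以是一个字典,这种情况下 index 默认是字典的关键字序列(在pandas 0.23及以上版本且Python 3.6及以上时保持插入顺序,更早的版本中会被排序):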
In [19]: pd.Series({2:'a', 1:'b', 3:'c'})
Out[19]: 2    a
1    b
3    c
dtype: object
In each case, the index can be explicitly set if a different result is preferred:
每种情况下,如果想要不同的结果,都可以显式地设置index:
In [20]: pd.Series({2:'a', 1:'b', 3:'c'}, index=[3, 2])
Out[20]: 3    c
2    a
dtype: object
Notice that in this case, the Series is populated only with the explicitly identified keys.
请注意,在这个例子中, Series 只包含了index中明确指定的关键字。
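The reverse case is also worth knowing. If the explicit index names a label that is not a key of the dictionary, pandas does not raise an error. It fills that entry with a missing value (NaN) instead. A quick sketch, where the label 4 is a hypothetical key absent from the dictionary:

```python
import pandas as pd

s = pd.Series({2: 'a', 1: 'b', 3: 'c'}, index=[3, 4])
print(s)
print(s.isna().tolist())  # [False, True]: label 4 had no matching key
```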
The Pandas DataFrame Object
Pandas的DataFrame对象
The next fundamental structure in Pandas is the DataFrame. Like the Series object discussed in the previous
section, the DataFrame can be thought of either as a generalization of a NumPy array, or as a specialization of a
Python dictionary. We'll now take a look at each of these perspectives.
Pandas的另一个基础数据结构是 DataFrame。就像刚才介绍的 Series 一样, DataFrame 既可以被当成是一种更通用的NumPy数组,也可以被当成是一种特殊的Python字典。下面来分别看看。
DataFrame as a generalized NumPy array
DataFrame作为一种通用的NumPy数组
If a Series is an analog of a one-dimensional array with flexible indices, a DataFrame is an analog of a two-dimensional array with both flexible row indices and flexible column names. Just as you might think of a two-dimensional
array as an ordered sequence of aligned one-dimensional columns, you can think of a DataFrame as a sequence of
aligned Series objects. Here, by "aligned" we mean that they share the same index.
如果说 Series 是带有灵活索引的一维数组的类比,那么 DataFrame 就是带有灵活的行索引和列名的二维数组的类比。正如你可以将二维数组看成一系列对齐的一维列一样,你也可以将 DataFrame 看成一系列对齐的 Series 对象。这里的“对齐”指的是这些 Series 拥有相同的索引。
To demonstrate this, let's first construct a new Series listing the area of each of the five states discussed in the
previous section:
下面我们构建一个新的 Series,存储上一节所讨论的美国5个州各自的面积,来说明这一点:
In [21]: area_dict = {'California': 423967, 'Texas': 695662, 'New York': 141297,
'Florida': 170312, 'Illinois': 149995}
area = pd.Series(area_dict)
area
Out[21]: California    423967
Texas         695662
New York      141297
Florida       170312
Illinois      149995
dtype: int64
Now that we have this along with the population Series from before, we can use a dictionary to construct a single
two-dimensional object containing this information:
现在我们有了面积和之前的人口两个 Series,可以使用一个字典来创建一个包含这些信息的二维对象:
In [22]: states = pd.DataFrame({'population': population,
'area': area})
states
Out[22]:
            population    area
California    38332521  423967
Texas         26448193  695662
New York      19651127  141297
Florida       19552860  170312
Illinois      12882135  149995
Like the Series object, the DataFrame has an index attribute that gives access to the index labels:
与 Series 对象一样, DataFrame 也有一个 index 属性,用来访问索引标签:
In [23]: states.index
Out[23]: Index(['California', 'Texas', 'New York', 'Florida', 'Illinois'], dtype='object')
Additionally, the DataFrame has a columns attribute, which is an Index object holding the column labels:
此外, DataFrame 还有一个 columns 属性,它同样是一个 Index 对象,存储着所有列的标签:
In [24]: states.columns
Out[24]: Index(['population', 'area'], dtype='object')
Thus the DataFrame can be thought of as a generalization of a two-dimensional NumPy array, where both the rows
and columns have a generalized index for accessing the data.
因此 DataFrame 可以被看成是二维NumPy数组的通用形式,它的行和列都带有通用的索引用来访问数据。
DataFrame as specialized dictionary
DataFrame作为特殊的字典
Similarly, we can also think of a DataFrame as a specialization of a dictionary. Where a dictionary maps a key to a
value, a DataFrame maps a column name to a Series of column data. For example, asking for the 'area'
column returns the Series object containing the areas we saw earlier:
类似地,我们也可以将 DataFrame 看成是一种特殊的字典。普通的字典将一个关键字key映射到一个值value,而 DataFrame 将一个列名映射到一个包含该列数据的 Series 对象。例如,访问 'area' 列会返回一个 Series 对象,里面包含前面我们放入的面积数据:
In [25]: states['area']
Out[25]: California    423967
Texas         695662
New York      141297
Florida       170312
Illinois      149995
Name: area, dtype: int64
Notice the potential point of confusion here: in a two-dimensional NumPy array, data[0] will return the first row. For a
DataFrame, data['col0'] will return the first column. Because of this, it is probably better to think about
DataFrames as generalized dictionaries rather than generalized arrays, though both ways of looking at the situation
can be useful. We'll explore more flexible means of indexing DataFrames in Data Indexing and Selection.
这里要注意一下容易混淆的地方:在NumPy的二维数组中, data[0] 会返回第一行数据,而在 DataFrame 中, data['col0'] 会返回第一列数据。正因为如此,最好将 DataFrame 当成是一种通用的字典而不是通用的二维数组,尽管两种看法都是有用的。我们会在数据索引和选择一节中讨论更灵活的 DataFrame 索引方式。
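The difference is easy to see side by side. In this sketch the column names `'col0'`, `'col1'` and `'col2'` are hypothetical, chosen to match the text:

```python
import numpy as np
import pandas as pd

arr = np.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
df = pd.DataFrame(arr, columns=['col0', 'col1', 'col2'])

print(arr[0])                # first ROW of the NumPy array: [0 1 2]
print(df['col0'].tolist())   # first COLUMN of the DataFrame: [0, 3]
```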
Constructing DataFrame objects
构建DataFrame对象
A Pandas DataFrame can be constructed in a variety of ways. Here we'll give several examples.
Pandas中的 DataFrame 可以通过多种方法构建。下面我们介绍几种方式。
From a single Series object
从单个Series对象构建
A DataFrame is a collection of Series objects, and a single-column DataFrame can be constructed from a single
Series:
DataFrame 是 Series 对象的集合,因此单列的 DataFrame 可以从单个的 Series 对象创建:
In [26]: pd.DataFrame(population, columns=['population'])
Out[26]:
            population
California    38332521
Texas         26448193
New York      19651127
Florida       19552860
Illinois      12882135
From a list of dicts
从字典的列表构建
Any list of dictionaries can be made into a DataFrame. We'll use a simple list comprehension to create some data:
任何字典的列表都可以用来创建 DataFrame。下面我们使用一个简单的列表解析表达式来创建一些数据:
In [27]: data = [{'a': i, 'b': 2 * i}
for i in range(3)]
pd.DataFrame(data)
Out[27]:
   a  b
0  0  0
1  1  2
2  2  4
Even if some keys in the dictionary are missing, Pandas will fill them in with NaN (i.e., "not a number") values:
即使字典中缺少某些关键字,Pandas也会自动将它们填充为 NaN(即“非数字”)值:
In [28]: pd.DataFrame([{'a': 1, 'b': 2}, {'b': 3, 'c': 4}])
Out[28]:
     a  b    c
0  1.0  2  NaN
1  NaN  3  4.0
From a dictionary of Series objects
从Series对象的字典构建
As we saw before, a DataFrame can be constructed from a dictionary of Series objects as well:
正如我们之前看到的, DataFrame 也可以从一个由 Series 对象构成的字典中创建:
In [29]: pd.DataFrame({'population': population,
'area': area})
Out[29]:
            population    area
California    38332521  423967
Texas         26448193  695662
New York      19651127  141297
Florida       19552860  170312
Illinois      12882135  149995
From a two-dimensional NumPy array
从一个二维NumPy数组构建
Given a two-dimensional array of data, we can create a DataFrame with any specified column and index names. If
omitted, an integer index will be used for each:
给定一个二维数组,我们可以指定任意的列名和行索引来创建一个 DataFrame。如果省略了行或列的索引,默认会使用整数索引:
In [30]: pd.DataFrame(np.random.rand(3, 2),
columns=['foo', 'bar'],
index=['a', 'b', 'c'])
Out[30]:
        foo       bar
a  0.435638  0.153130
b  0.070155  0.671968
c  0.974456  0.358945
From a NumPy structured array
从NumPy结构化数组构建
We covered structured arrays in Structured Data: NumPy's Structured Arrays. A Pandas DataFrame operates much
like a structured array, and can be created directly from one:
上一章最后一节我们介绍了结构化数组(参见结构化数据:NumPy里的结构化数组)。Pandas的 DataFrame 对象的行为与结构化数组非常接近,因此可以直接从后者构建:
In [31]: A = np.zeros(3, dtype=[('A', 'i8'), ('B', 'f8')])
A
Out[31]: array([(0, 0.), (0, 0.), (0, 0.)], dtype=[('A', '<i8'), ('B', '<f8')])
In [32]: pd.DataFrame(A)
Out[32]:
   A    B
0  0  0.0
1  0  0.0
2  0  0.0
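Applied to the four-person structured array from the previous chapter, this means that the field names become column labels. The age filter from that section can then be written against the DataFrame. A sketch that rebuilds the array so that it runs on its own:

```python
import numpy as np
import pandas as pd

data = np.zeros(4, dtype={'names': ('name', 'age', 'weight'),
                          'formats': ('U10', 'i4', 'f8')})
data['name'] = ['Alice', 'Bob', 'Cathy', 'Doug']
data['age'] = [25, 45, 37, 19]
data['weight'] = [55.0, 85.5, 68.0, 61.5]

df = pd.DataFrame(data)        # field names become the column labels
print(list(df.columns))        # ['name', 'age', 'weight']

# Same filter as data[data['age'] < 30]['name'] on the structured array
print(df[df['age'] < 30]['name'].tolist())  # ['Alice', 'Doug']
```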
The Pandas Index Object
Pandas的Index对象
We have seen here that both the Series and DataFrame objects contain an explicit index that lets you reference
and modify data. This Index object is an interesting structure in itself, and it can be thought of either as an immutable
array or as an ordered set (technically a multi-set, as Index objects may contain repeated values). Those views have
some interesting consequences in the operations available on Index objects. As a simple example, let's construct an
Index from a list of integers:
前面介绍的 Series 和 DataFrame 对象都包含着一个显式定义的索引,让你可以引用和修改数据。Index 对象本身就是一个很有趣的数据结构,它可以被看成一个不可变的数组,或者一个有序集合(严格来说是多重集合,因为 Index 对象可以包含重复的值)。这两种看法使得 Index 对象上可用的操作具有一些有趣的特性。先以一个简单的例子来说明,我们从整数列表构建一个 Index 对象:
In [33]: ind = pd.Index([2, 3, 5, 7, 11])
ind
Out[33]: Int64Index([2, 3, 5, 7, 11], dtype='int64')
Index as immutable array
Index作为不可变数组
The Index in many ways operates like an array. For example, we can use standard Python indexing notation to
retrieve values or slices:
Index 的很多操作都像一个数组。例如,我们可以使用标准的Python索引语法来获得值和切片:
In [34]: ind[1]
Out[34]: 3
In [35]: ind[::2]
Out[35]: Int64Index([2, 5, 11], dtype='int64')
Index objects also have many of the attributes familiar from NumPy arrays:
Index 对象也有很多你熟悉的NumPy数组属性:
In [36]: print(ind.size, ind.shape, ind.ndim, ind.dtype)
5 (5,) 1 int64
One difference between Index objects and NumPy arrays is that indices are immutable–that is, they cannot be
modified via the normal means:
NumPy数组和 Index 对象的一个区别是,Index 是不可变的,也就是说无法通过通常的方式修改它的元素值:
In [37]: ind[1] = 0
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-37-906a9fa1424c> in <module>
----> 1 ind[1] = 0

~/anaconda3/lib/python3.7/site-packages/pandas/core/indexes/base.py in __setitem__(self, key, value)
   3936
   3937     def __setitem__(self, key, value):
-> 3938         raise TypeError("Index does not support mutable operations")
   3939
   3940     def __getitem__(self, key):

TypeError: Index does not support mutable operations
This immutability makes it safer to share indices between multiple DataFrames and arrays, without the potential for
side effects from inadvertent index modification.
这种不变性使得在多个 DataFrame 和数组之间共享索引更加安全,避免了因疏忽修改索引而带来的副作用。
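In practice, "changing" an index means building a new one. The sketch below uses the `Index.delete` and `Index.insert` methods. Both return a new `Index` and leave the original untouched:

```python
import pandas as pd

ind = pd.Index([2, 3, 5, 7, 11])

try:
    ind[1] = 0                      # in-place assignment is rejected
except TypeError as err:
    print("TypeError:", err)

# Build a modified copy instead: drop position 1, then insert 0 there
new_ind = ind.delete(1).insert(1, 0)
print(new_ind.tolist())  # [2, 0, 5, 7, 11]
print(ind.tolist())      # [2, 3, 5, 7, 11]  (original unchanged)
```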
Index as ordered set
Index作为有序集合
Pandas objects are designed to facilitate operations such as joins across datasets, which depend on many aspects of set
arithmetic. The Index object follows many of the conventions used by Python's built-in set data structure, so that
unions, intersections, differences, and other combinations can be computed in a familiar way:
Pandas对象被设计成便于进行跨数据集的操作,例如连接(join),而这很大程度上依赖于集合运算。Index 对象遵循Python内建的 set 数据结构所使用的许多惯例,因此并集、交集、差集和其他的集合组合操作也可以按照熟悉的方式进行:
In [38]: indA = pd.Index([1, 3, 5, 7, 9])
indB = pd.Index([2, 3, 5, 7, 11])
In [39]: indA & indB  # intersection
Out[39]: Int64Index([3, 5, 7], dtype='int64')
In [40]: indA | indB  # union
Out[40]: Int64Index([1, 2, 3, 5, 7, 9, 11], dtype='int64')
In [41]: indA ^ indB  # symmetric difference
Out[41]: Int64Index([1, 2, 9, 11], dtype='int64')
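These operations may also be accessed via object methods, for example indA.intersection(indB). In later pandas releases the operator forms were deprecated as set operations (from pandas 2.0, &, | and ^ on two Index objects act element-wise instead), so the methods are the more portable choice.
这些操作也可以通过对象的方法来实现,例如 indA.intersection(indB)。在较新的pandas版本中,这些运算符形式的集合操作已被弃用(从pandas 2.0起,两个 Index 之间的 &、| 和 ^ 改为逐元素运算),因此使用方法的写法兼容性更好。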
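The method forms make each set operation explicit. A short sketch with the same two indices, using the `Index` methods `intersection`, `union`, `symmetric_difference` and `difference` with their default arguments:

```python
import pandas as pd

indA = pd.Index([1, 3, 5, 7, 9])
indB = pd.Index([2, 3, 5, 7, 11])

print(indA.intersection(indB).tolist())          # [3, 5, 7]
print(indA.union(indB).tolist())                 # [1, 2, 3, 5, 7, 9, 11]
print(indA.symmetric_difference(indB).tolist())  # [1, 2, 9, 11]
print(indA.difference(indB).tolist())            # [1, 9]  (in indA but not indB)
```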
Data Indexing and Selection
数据索引和选择
In Chapter 2, we looked in detail at methods and tools to access, set, and modify values in NumPy arrays. These
included indexing (e.g., arr[2, 1]), slicing (e.g., arr[:, 1:5]), masking (e.g., arr[arr > 0]), fancy indexing
(e.g., arr[0, [1, 5]]), and combinations thereof (e.g., arr[:, [1, 5]]). Here we'll look at similar means of
accessing and modifying values in Pandas Series and DataFrame objects. If you have used the NumPy patterns,
the corresponding patterns in Pandas will feel very familiar, though there are a few quirks to be aware of.
在第二章,我们详细学习了在NumPy数组中获取、设置和修改值的方法和工具。这些方法包括索引(如 arr[2, 1]),切片(如 arr[:, 1:5]),遮盖(如 arr[arr > 0]),高级索引(如 arr[0, [1, 5]]),以及上述方法的组合(如 arr[:, [1, 5]])。下面我们将介绍在Pandas的 Series 和 DataFrame 对象中获取和修改值的类似方法。如果你已经熟悉了NumPy的操作,那么Pandas中对应的操作对你来说也会很熟悉,只需要注意一些特别的地方。
We'll start with the simple case of the one-dimensional Series object, and then move on to the more complicated two-dimensional DataFrame object.
我们会从最简单的一维 Series 开始学习,然后再进入复杂一些的二维 DataFrame 对象。
Data Selection in Series
在Series中选择数据
As we saw in the previous section, a Series object acts in many ways like a one-dimensional NumPy array, and in
many ways like a standard Python dictionary. If we keep these two overlapping analogies in mind, it will help us to
understand the patterns of data indexing and selection in these arrays.
我们上一节已经看到, Series 对象在很多方面都表现得像一个一维NumPy数组,同时在很多方面又像一个标准的Python字典。如果我们记住这两个相互重叠的类比,就能更好地理解Series的数据索引和选择方法。
Series as dictionary
将Series看成字典
Like a dictionary, the Series object provides a mapping from a collection of keys to a collection of values:
像字典一样, Series 对象提供了从关键字集合到值集合的映射:
In [1]: import pandas as pd
data = pd.Series([0.25, 0.5, 0.75, 1.0],
index=['a', 'b', 'c', 'd'])
data
Out[1]: a    0.25
b    0.50
c    0.75
d    1.00
dtype: float64
In [2]: data['b']
Out[2]: 0.5
We can also use dictionary-like Python expressions and methods to examine the keys/indices and values:
我们还可以使⽤标准Python字典的表达式和⽅法来检查Series的关键字和值:
In [3]: 'a' in data
Out[3]: True
In [4]: data.keys()
Out[4]: Index(['a', 'b', 'c', 'd'], dtype='object')
In [5]: list(data.items())
Out[5]: [('a', 0.25), ('b', 0.5), ('c', 0.75), ('d', 1.0)]
Series objects can even be modified with a dictionary-like syntax. Just as you can extend a dictionary by assigning to
a new key, you can extend a Series by assigning to a new index value:
Series 对象还可以使用字典操作进行修改。就像你可以给字典的一个新的关键字赋值一样,你可以新增一个index关键字来扩展 Series 。
In [6]: data['e'] = 1.25
        data
Out[6]: a    0.25
        b    0.50
        c    0.75
        d    1.00
        e    1.25
        dtype: float64
This easy mutability of the objects is a convenient feature: under the hood, Pandas is making decisions about memory
layout and data copying that might need to take place; the user generally does not need to worry about these issues.
这样简便的修改对象的⽅法是⼀个有⽤的特性:虽然在底层Pandas会对内存分配和数据复制等进⾏操作,但是⽤⼾通常不需要担⼼这⼀
点。
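A minimal sketch of the dictionary analogy, using a Series shaped like the one in In [1] and assuming a reasonably recent Pandas. Series.get is the Series counterpart of dict.get, and a membership test with in looks at the index (the keys), not at the values.

```python
import pandas as pd

# Same data as In [1]
data = pd.Series([0.25, 0.5, 0.75, 1.0], index=['a', 'b', 'c', 'd'])

# Membership tests check the index (the "keys"), not the values
print('a' in data)      # True
print(0.25 in data)     # False: 0.25 is a value, not a key

# Series.get mirrors dict.get: a default instead of a KeyError
print(data.get('z', 'missing'))   # missing

# Assigning to a new key extends the Series in place
data['e'] = 1.25
print(list(data.index))  # ['a', 'b', 'c', 'd', 'e']
```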
Series as one-dimensional array
将Series看成⼀维数组
A Series builds on this dictionary-like interface and provides array-style item selection via the same basic mechanisms
as NumPy arrays – that is, slices, masking, and fancy indexing. Examples of these are as follows:
Series 对象构建在字典一样的接口之上,并且提供了和NumPy数组一样的数据选择方式,即切片、遮盖和高级索引。请看下面的例子:
In [7]: # slicing by explicit index / 使用显式索引值切片
        data['a':'c']
Out[7]: a    0.25
        b    0.50
        c    0.75
        dtype: float64
In [8]: # slicing by implicit integer index / 使用隐式整数索引值切片
        data[0:2]
Out[8]: a    0.25
        b    0.50
        dtype: float64
In [9]: # masking / 遮盖
        data[(data > 0.3) & (data < 0.8)]
Out[9]: b    0.50
        c    0.75
        dtype: float64
In [10]: # fancy indexing / 高级索引
         data[['a', 'e']]
Out[10]: a    0.25
         e    1.25
         dtype: float64
Among these, slicing may be the source of the most confusion. Notice that when slicing with an explicit index (i.e.,
data['a':'c'] ), the final index is included in the slice, while when slicing with an implicit index (i.e., data[0:2] ),
the final index is excluded from the slice.
在上⾯的例⼦当中,切⽚可能是最容易让⼈误解的。⾸先看到使⽤指定的显式索引进⾏切⽚(例如 data['a':'c'] ),结束位置的索引
值是包含在切⽚⾥⾯的,然⽽,使⽤隐式索引进⾏切⽚(例如 data[0:2] ),结束位置的索引值是不包含在切⽚⾥⾯的。
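To make the endpoint rule concrete, here is a small sketch that uses the explicit loc and iloc indexers introduced in the next subsection; for this string-indexed Series they select the same items as data['a':'c'] and data[0:2].

```python
import pandas as pd

data = pd.Series([0.25, 0.5, 0.75, 1.0], index=['a', 'b', 'c', 'd'])

# Label-based slice: the end label 'c' IS included
by_label = data.loc['a':'c']

# Position-based slice: the end position 2 is NOT included
by_position = data.iloc[0:2]

print(list(by_label.index))     # ['a', 'b', 'c']
print(list(by_position.index))  # ['a', 'b']
```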
Indexers: loc, iloc, and ix
索引符:loc,iloc 和 ix
These slicing and indexing conventions can be a source of confusion. For example, if your Series has an explicit
integer index, an indexing operation such as data[1] will use the explicit indices, while a slicing operation like
data[1:3] will use the implicit Python-style index.
仔细想⼀下,你会发现这样的切⽚和索引操作是会造成混乱的。例如,如果 Series 对象有显式的整数索引,那么 data[1] 的操作会使
⽤显式索引,但是 data[1:3] 的操作会使⽤隐式索引。
In [11]: data = pd.Series(['a', 'b', 'c'], index=[1, 3, 5])
         data
Out[11]: 1    a
         3    b
         5    c
         dtype: object
In [12]: # explicit index when indexing / 索引时使用显式索引
         data[1]
Out[12]: 'a'
In [13]: # implicit index when slicing / 切片时使用隐式索引
         data[1:3]
Out[13]: 3    b
         5    c
         dtype: object
Because of this potential confusion in the case of integer indexes, Pandas provides some special indexer attributes that
explicitly expose certain indexing schemes. These are not functional methods, but attributes that expose a particular
slicing interface to the data in the Series .
因为存在上⾯看到的这种混乱,Pandas提供了⼀些特殊的索引符属性来明确指定使⽤哪种索引规则。这些索引符不是函数,⽽是⽤来访
问 Series 数据的切⽚属性。
First, the loc attribute allows indexing and slicing that always references the explicit index:
⾸先, loc 属性允许⽤⼾永远使⽤显式索引来进⾏定位和切⽚:
In [14]: data.loc[1]
Out[14]: 'a'
In [15]: data.loc[1:3]
Out[15]: 1    a
         3    b
         dtype: object
The iloc attribute allows indexing and slicing that always references the implicit Python-style index:
iloc 属性允许用户永远使用隐式索引来定位和切片:
In [16]: data.iloc[1]
Out[16]: 'b'
In [17]: data.iloc[1:3]
Out[17]: 3    b
         5    c
         dtype: object
A third indexing attribute, ix , is a hybrid of the two, and for Series objects is equivalent to standard [] -based indexing. The purpose of the ix indexer will become more apparent in the context of DataFrame objects, which we will discuss in a moment. (Note that ix was deprecated in Pandas 0.20 and removed in Pandas 1.0.)
第三个索引符属性 ix ,是两者的混合,对于 Series 对象来说,等同于标准的 [] 索引。 ix 索引符的意义会在 DataFrame 对象中体现出来,我们很快就会讨论到。(注意: ix 已在 Pandas 0.20 中被弃用,并在 Pandas 1.0 中被移除。)
One guiding principle of Python code is that "explicit is better than implicit." The explicit nature of loc and iloc makes them very useful in maintaining clean and readable code; especially in the case of integer indexes, I recommend using both, to make code easier to read and understand and to prevent subtle bugs due to the mixed indexing/slicing convention.
Python 编码的一大原则就是“明确含义优于隐含意义”。 loc 和 iloc 属性的明确含义使得它们在维护干净和可读的代码方面非常有效;尤其是在使用显式整数索引的情况下,作者推荐坚持使用它们,既能保证代码的易读性,也能防止因为前面提到的混合索引/切片规则造成的难以发现的bug。
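The recommendation can be checked directly. This sketch reuses the integer-indexed Series from In [11] and contrasts the two indexers; note that a loc slice includes its end label while an iloc slice excludes its end position.

```python
import pandas as pd

# Integer labels that differ from positions: the ambiguous case
data = pd.Series(['a', 'b', 'c'], index=[1, 3, 5])

label_item = data.loc[1]       # label 1          -> 'a'
label_slice = data.loc[1:3]    # labels 1 to 3    -> 'a', 'b' (end included)

pos_item = data.iloc[1]        # position 1       -> 'b'
pos_slice = data.iloc[1:3]     # positions 1, 2   -> 'b', 'c' (end excluded)

print(label_item, list(label_slice))   # a ['a', 'b']
print(pos_item, list(pos_slice))       # b ['b', 'c']
```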
Data Selection in DataFrame
DataFrame的数据选择
Recall that a DataFrame acts in many ways like a two-dimensional or structured array, and in other ways like a
dictionary of Series structures sharing the same index. These analogies can be helpful to keep in mind as we explore
data selection within this structure.
回忆上⼀节,我们介绍过 DataFrame 表现得既像⼆维数组⼜像由共同的索引值组成的 Series 对象的字典。这个概念也能帮助你学习如
何在 DataFrame ⾥⾯进⾏数据选择的⽅法。
DataFrame as a dictionary
将DataFrame当成字典
The first analogy we will consider is the DataFrame as a dictionary of related Series objects. Let's return to our
example of areas and populations of states:
⾸先我们将 DataFrame 看成是相关 Series 对象组成的字典。让我们回到之前那个美国州⼈⼝和⾯积的例⼦:
In [18]: area = pd.Series({'California': 423967, 'Texas': 695662,
                           'New York': 141297, 'Florida': 170312,
                           'Illinois': 149995})
         pop = pd.Series({'California': 38332521, 'Texas': 26448193,
                          'New York': 19651127, 'Florida': 19552860,
                          'Illinois': 12882135})
         data = pd.DataFrame({'area':area, 'pop':pop})
         data
Out[18]:
              area       pop
California  423967  38332521
Texas       695662  26448193
New York    141297  19651127
Florida     170312  19552860
Illinois    149995  12882135
The individual Series that make up the columns of the DataFrame can be accessed via dictionary-style indexing of
the column name:
这个 DataFrame 中的列分别由两个独⽴的 Series 构成,它们可以使⽤字典⽅式的关键字进⾏访问:
In [19]: data['area']
Out[19]: California    423967
         Texas         695662
         New York      141297
         Florida       170312
         Illinois      149995
         Name: area, dtype: int64
Equivalently, we can use attribute-style access with column names that are strings:
同样的,当列的名字是字符串时,我们也可以使⽤属性的⽅式访问:
In [20]: data.area
Out[20]: California    423967
         Texas         695662
         New York      141297
         Florida       170312
         Illinois      149995
         Name: area, dtype: int64
This attribute-style column access actually accesses the exact same object as the dictionary-style access:
使⽤字典⽅式和使⽤属性⽅式访问的列对象是同⼀个:
In [21]: data.area is data['area']
Out[21]: True
Though this is a useful shorthand, keep in mind that it does not work for all cases! For example, if the column names are
not strings, or if the column names conflict with methods of the DataFrame , this attribute-style access is not possible.
For example, the DataFrame has a pop() method, so data.pop will point to this rather than the "pop" column:
虽然这是个有⽤的缩写⽅式,但是请记住属性表达式并不是通⽤的。例如,如果列名不是字符串,或者与 DataFrame 的⽅法名字发⽣冲
突,属性表达式都没法使⽤。例如, DataFrame 有 pop() ⽅法,因此, data.pop 将会指向该⽅法⽽不是 "pop" 列:
In [22]: data.pop is data['pop']
Out[22]: False
In particular, you should avoid the temptation to try column assignment via attribute (i.e., use data['pop'] = z rather
than data.pop = z ).
特别是应该避免使⽤属性表达式给列赋值(例如,应该使⽤ data['pop']=z ⽽不是 data.pop=z )。
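A small sketch of both pitfalls, assuming a recent Pandas: data.pop resolves to the DataFrame.pop() method, and assigning to a new attribute name does not create a column (Pandas emits a UserWarning for this, silenced here). The two-state frame and the ratio column name are hypothetical.

```python
import warnings
import pandas as pd

data = pd.DataFrame({'area': [423967, 695662],
                     'pop': [38332521, 26448193]},
                    index=['California', 'Texas'])

# The attribute name collides with DataFrame.pop(), so the method wins
print(callable(data.pop))            # True
print(data.pop is data['pop'])       # False

# Attribute assignment to a new name sets a plain Python attribute, not a column
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # Pandas warns about this pattern
    data.ratio = data['pop'] / data['area']
created_by_attribute = 'ratio' in data.columns
print(created_by_attribute)          # False

# Dictionary-style assignment is the reliable way to add a column
data['ratio'] = data['pop'] / data['area']
print('ratio' in data.columns)       # True
```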
Like with the Series objects discussed earlier, this dictionary-style syntax can also be used to modify the object, in this
case adding a new column:
与 Series 对象⼀样,你也可以通过为⼀个新的关键字赋值来向 DataFrame 中添加新的列:
In [23]: data['density'] = data['pop'] / data['area']
         data
Out[23]:
              area       pop     density
California  423967  38332521   90.413926
Texas       695662  26448193   38.018740
New York    141297  19651127  139.076746
Florida     170312  19552860  114.806121
Illinois    149995  12882135   85.883763
This shows a preview of the straightforward syntax of element-by-element arithmetic between Series objects; we'll dig
into this further in Operating on Data in Pandas.
这⾥展⽰了使⽤直接的语法对多个 Series 对象按元素进⾏算术运算;我们会在在Pandas中操作数据⼀节中深⼊讨论。
DataFrame as two-dimensional array
将DataFrame看成⼆维数组
As mentioned previously, we can also view the DataFrame as an enhanced two-dimensional array. We can examine
the raw underlying data array using the values attribute:
前⾯说到,我们也可以将 DataFrame 看成是⼀个扩展的⼆维数组。我们可以通过 values 属性查看 DataFrame 对象的底层数组:
In [24]: data.values
Out[24]: array([[4.23967000e+05, 3.83325210e+07, 9.04139261e+01],
                [6.95662000e+05, 2.64481930e+07, 3.80187404e+01],
                [1.41297000e+05, 1.96511270e+07, 1.39076746e+02],
                [1.70312000e+05, 1.95528600e+07, 1.14806121e+02],
                [1.49995000e+05, 1.28821350e+07, 8.58837628e+01]])
With this picture in mind, many familiar array-like operations can be done on the DataFrame itself. For example, we can transpose the full DataFrame to swap rows and columns:
有了这个基本概念之后,很多熟悉的数组操作都可以应用在 DataFrame 对象上。例如,我们可以将 DataFrame 的行和列交换,也就是矩阵的转置:
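A short sketch of the array view, assuming Pandas 0.24 or later for to_numpy(), which the Pandas documentation recommends over .values. The three-state frame is a trimmed copy of the example data.

```python
import pandas as pd

data = pd.DataFrame({'area': [423967, 695662, 141297],
                     'pop': [38332521, 26448193, 19651127]},
                    index=['California', 'Texas', 'New York'])

arr = data.to_numpy()        # the underlying 2-D array, like data.values
print(arr.shape)             # (3, 2): one row per state, one column per field

flipped = data.T             # transpose: index and columns trade places
print(flipped.shape)         # (2, 3)
print(list(flipped.index))   # ['area', 'pop']
```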
In [25]: data.T
Out[25]:
           California         Texas      New York       Florida      Illinois
area     4.239670e+05  6.956620e+05  1.412970e+05  1.703120e+05  1.499950e+05
pop      3.833252e+07  2.644819e+07  1.965113e+07  1.955286e+07  1.288214e+07
density  9.041393e+01  3.801874e+01  1.390767e+02  1.148061e+02  8.588376e+01
When it comes to indexing of DataFrame objects, however, it is clear that the dictionary-style indexing of columns
precludes our ability to simply treat it as a NumPy array. In particular, passing a single index to an array accesses a row:
当我们需要对 DataFrame 对象进行索引时,因为列所具有的字典索引方式,我们无法简单地把它当作NumPy数组来处理。比方说,对数组传递一个索引值会获取一行:
In [26]: data.values[0]
Out[26]: array([4.23967000e+05, 3.83325210e+07, 9.04139261e+01])
and passing a single "index" to a DataFrame accesses a column:
而对 DataFrame 传递一个“索引值”则会获取一列:
In [27]: data['area']
Out[27]: California    423967
         Texas         695662
         New York      141297
         Florida       170312
         Illinois      149995
         Name: area, dtype: int64
Thus for array-style indexing, we need another convention. Here Pandas again uses the loc , iloc , and ix
indexers mentioned earlier. Using the iloc indexer, we can index the underlying array as if it is a simple NumPy array
(using the implicit Python-style index), but the DataFrame index and column labels are maintained in the result:
因此对于数组⽅式的索引⽅式,我们需要使⽤另⼀种⽅法。Pandas仍然使⽤ loc 、 iloc 和 ix 索引符来进⾏操作。当你使⽤ iloc
时,这就是使⽤隐式索引,Pandas会把 DataFrame 当成底层的NumPy数组来处理,但⾏和列的索引值还是会保留在结果中:
In [28]: data.iloc[:3, :2]
Out[28]:
              area       pop
California  423967  38332521
Texas       695662  26448193
New York    141297  19651127
Similarly, using the loc indexer we can index the underlying data in an array-like style but using the explicit index and
column names:
类似的,使用 loc 索引符时,我们同样以数组的方式进行索引,但使用的是明确指定的显式索引和列名:
In [29]: data.loc[:'Illinois', :'pop']
Out[29]:
              area       pop
California  423967  38332521
Texas       695662  26448193
New York    141297  19651127
Florida     170312  19552860
Illinois    149995  12882135
The ix indexer allows a hybrid of these two approaches:
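ix 索引符是上两种方式的混合体:
Translator's note: ix was deprecated in Pandas 0.20 (hence the DeprecationWarning below) and removed in Pandas 1.0, where this cell raises an AttributeError; use loc or iloc instead.
译者注: ix 已经在新版的Pandas中被弃用(Pandas 0.20 起会产生下面的警告),并在 Pandas 1.0 中被彻底移除,届时这段代码会直接抛出 AttributeError ;读者应该改用 loc 或 iloc 。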
In [30]: data.ix[:3, :'pop']
/home/wangy/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:1: DeprecationWarning:
.ix is deprecated. Please use
.loc for label based indexing or
.iloc for positional indexing
See the documentation here:
http://pandas.pydata.org/pandas-docs/stable/indexing.html#ix-indexer-is-deprecated
  """Entry point for launching an IPython kernel.
Out[30]:
              area       pop
California  423967  38332521
Texas       695662  26448193
New York    141297  19651127
Keep in mind that for integer indices, the ix indexer is subject to the same potential sources of confusion as discussed
for integer-indexed Series objects.
请牢记对于整型的索引来说, ix 同样也会产⽣之前在 Series 中阐述的那种混乱情况。
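Because ix is no longer available in current Pandas (it was removed in 1.0), the hybrid request in In [30] (first three rows by position, columns up to 'pop' by label) has to be expressed with loc and iloc. A sketch of three equivalent spellings, assuming Pandas 1.0 or later:

```python
import pandas as pd

data = pd.DataFrame({'area': [423967, 695662, 141297, 170312, 149995],
                     'pop': [38332521, 26448193, 19651127, 19552860, 12882135]},
                    index=['California', 'Texas', 'New York', 'Florida', 'Illinois'])
data['density'] = data['pop'] / data['area']

# Option 1: a positional step followed by a label step
option1 = data.iloc[:3].loc[:, :'pop']

# Option 2: translate the row positions into labels, then use .loc once
option2 = data.loc[data.index[:3], :'pop']

# Option 3: translate the column label into a position, then use .iloc once
stop = data.columns.get_loc('pop') + 1   # +1 because .iloc slices exclude the end
option3 = data.iloc[:3, :stop]

print(option1)
```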
Any of the familiar NumPy-style data access patterns can be used within these indexers. For example, in the loc
indexer we can combine masking and fancy indexing as in the following:
然后,任何NumPy中熟悉的操作都可以在上⾯的索引符中使⽤。例如, loc 索引符中我们可以结合遮盖和⾼级索引模式:
In [31]: data.loc[data.density > 100, ['pop', 'density']]
Out[31]:
               pop     density
New York  19651127  139.076746
Florida   19552860  114.806121
Any of these indexing conventions may also be used to set or modify values; this is done in the standard way that you
might be accustomed to from working with NumPy:
上⾯的索引⽅式可以⽤来设置或修改数据;这可以通过你已经熟悉的NumPy的标准⽅式来进⾏:
In [32]: data.iloc[0, 2] = 90
         data
Out[32]:
              area       pop     density
California  423967  38332521   90.000000
Texas       695662  26448193   38.018740
New York    141297  19651127  139.076746
Florida     170312  19552860  114.806121
Illinois    149995  12882135   85.883763
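As a sketch of the same idea with a Boolean mask, assuming a recent Pandas: passing the row mask and the column label to a single loc call modifies the frame itself, whereas a chained form such as data[mask]['density'] = 100 may act on a temporary copy and leave data unchanged. The data is a trimmed copy of the example and the cap of 100 is arbitrary.

```python
import pandas as pd

data = pd.DataFrame({'area': [141297, 170312, 149995],
                     'pop': [19651127, 19552860, 12882135]},
                    index=['New York', 'Florida', 'Illinois'])
data['density'] = data['pop'] / data['area']

# Row mask + column label in ONE indexer call: assigns into `data` itself
data.loc[data['density'] > 100, 'density'] = 100.0

# Positional assignment works the same way through .iloc (row 2, column 0)
data.iloc[2, 0] = 150000

print(data)
```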
To build up your fluency in Pandas data manipulation, I suggest spending some time with a simple DataFrame and
exploring the types of indexing, slicing, masking, and fancy indexing that are allowed by these various indexing
approaches.
为了锻炼你操作Pandas数据的熟练度,作者建议花些时间构建⼀个简单的 DataFrame 对象,然后在上⾯运⽤索引、切⽚、遮盖和⾼级索
引等各种操作。
Additional indexing conventions
额外索引规则
There are a couple of extra indexing conventions that might seem at odds with the preceding discussion, but nevertheless can be very useful in practice. First, while indexing refers to columns, slicing refers to rows:
除了上⾯介绍的,还有⼀些额外的索引规则在实践中也很有⽤处。⾸先索引是针对列的,⽽切⽚是针对⾏的:
In [33]: data['Florida':'Illinois']
Out[33]:
            area       pop     density
Florida   170312  19552860  114.806121
Illinois  149995  12882135   85.883763
Such slices can also refer to rows by number rather than by index:
这样的切⽚操作也可以通过⾏的序号来索引:
In [34]: data[1:3]
Out[34]:
            area       pop     density
Texas     695662  26448193   38.018740
New York  141297  19651127  139.076746
Similarly, direct masking operations are also interpreted row-wise rather than column-wise:
类似的,直接的遮盖操作也是对⾏的操作⽽不是对列的操作:
In [35]: data[data.density > 100]
Out[35]:
            area       pop     density
New York  141297  19651127  139.076746
Florida   170312  19552860  114.806121
These two conventions are syntactically similar to those on a NumPy array, and while these may not precisely fit the mold
of the Pandas conventions, they are nevertheless quite useful in practice.
上⾯两个规则与NumPy数组语法保持⼀致,然⽽他们和Pandas⻛格可能并不完全⼀致,但是它们在实践中还是很有⽤的。
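Each of these shorthands can be rewritten with the explicit indexers, which some readers may find clearer. A sketch that checks the equivalences on the example data, assuming a recent Pandas:

```python
import pandas as pd

data = pd.DataFrame({'area': [423967, 695662, 141297, 170312, 149995],
                     'pop': [38332521, 26448193, 19651127, 19552860, 12882135]},
                    index=['California', 'Texas', 'New York', 'Florida', 'Illinois'])
data['density'] = data['pop'] / data['area']

# Slices and masks inside [] select ROWS; .loc/.iloc say so explicitly
label_rows_ok = data['Florida':'Illinois'].equals(data.loc['Florida':'Illinois'])
position_rows_ok = data[1:3].equals(data.iloc[1:3])
mask = data['density'] > 100
mask_rows_ok = data[mask].equals(data.loc[mask])

# ...while a single label inside [] selects a COLUMN
column_ok = data['area'].equals(data.loc[:, 'area'])

print(label_rows_ok, position_rows_ok, mask_rows_ok, column_ok)
```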
Operating on Data in Pandas
在Pandas中操作数据
One of the essential pieces of NumPy is the ability to perform quick element-wise operations, both with basic arithmetic
(addition, subtraction, multiplication, etc.) and with more sophisticated operations (trigonometric functions, exponential
and logarithmic functions, etc.). Pandas inherits much of this functionality from NumPy, and the ufuncs that we introduced
in Computation on NumPy Arrays: Universal Functions are key to this.
NumPy 的一个关键能力就是快速地进行逐个元素运算,无论是基础算术运算(加法、减法、乘法等)还是更加复杂的运算(三角函数、指数函数、对数函数等)。Pandas从NumPy中继承了这种能力的大部分,我们在使用Numpy计算:通用函数中介绍的ufuncs就是提供这种能力的关键。
Pandas includes a couple of useful twists, however: for unary operations like negation and trigonometric functions, these
ufuncs will preserve index and column labels in the output, and for binary operations such as addition and multiplication,
Pandas will automatically align indices when passing the objects to the ufunc. This means that keeping the context of
data and combining data from different sources (both potentially error-prone tasks with raw NumPy arrays) become
essentially foolproof with Pandas. We will additionally see that there are well-defined operations between one-dimensional Series structures and two-dimensional DataFrame structures.
然⽽Pandas包括⼀些NumPy不具备的特性:对于⼀元运算如取负和三⻆函数,这些ufuncs会在结果中保留原来的index和column标签;对
于⼆元运算如加法和乘法,Pandas会⾃动在结果中对参与运算的数据集进⾏索引对⻬操作。这意味着在NumPy中对于不同数据集操作时
以及需要保持数据的信息很容易发⽣错误的情况,在Pandas中就会很难会发⽣。我们还会看到针对⼀维的 Series 对象和⼆维的
DataFrame 对象都有定义良好的操作。
Ufuncs: Index Preservation
Ufuncs:保留索引
Because Pandas is designed to work with NumPy, any NumPy ufunc will work on Pandas Series and DataFrame
objects. Let's start by defining a simple Series and DataFrame on which to demonstrate this:
因为Pandas是设计和NumPy⼀起使⽤的,因此所有的NumPy通⽤函数都可以在Pandas的 Series 和 DataFrame 对象上使⽤。⾸先我
们定义简单的 Series 和 DataFrame 对象来展⽰:
In [1]: import pandas as pd
        import numpy as np
In [2]: rng = np.random.RandomState(42)
        ser = pd.Series(rng.randint(0, 10, 4))
        ser
Out[2]: 0    6
        1    3
        2    7
        3    4
        dtype: int64
In [3]: df = pd.DataFrame(rng.randint(0, 10, (3, 4)),
                          columns=['A', 'B', 'C', 'D'])
        df
Out[3]:
   A  B  C  D
0  6  9  2  6
1  7  4  3  7
2  7  2  5  4
If we apply a NumPy ufunc on either of these objects, the result will be another Pandas object with the indices preserved:
如果我们对上⾯的⼀个对象使⽤⼀元ufunc运算,结果会产⽣另⼀个Pandas对象,且保留了索引:
In [4]: np.exp(ser)
Out[4]: 0     403.428793
        1      20.085537
        2    1096.633158
        3      54.598150
        dtype: float64
Or, for a slightly more complex calculation:
下⾯是⼀个更加复杂的计算:
In [5]: np.sin(df * np.pi / 4)
Out[5]:
          A             B         C             D
0 -1.000000  7.071068e-01  1.000000 -1.000000e+00
1 -0.707107  1.224647e-16  0.707107 -7.071068e-01
2 -0.707107  1.000000e+00 -0.707107  1.224647e-16
Any of the ufuncs discussed in Computation on NumPy Arrays: Universal Functions can be used in a similar manner.
任何我们在使⽤Numpy计算:通⽤函数中讨论过的ufuncs都可以按照类似的⽅式进⾏运算。
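A quick check of index preservation, assuming NumPy and Pandas are installed; the labels 'w' to 'z' are hypothetical. Applying the same ufunc to ser.to_numpy() shows what is lost when the labels are stripped.

```python
import numpy as np
import pandas as pd

ser = pd.Series([6, 3, 7, 4], index=['w', 'x', 'y', 'z'])

# A unary ufunc on a Series returns a Series with the same index
roots = np.sqrt(ser)
print(type(roots).__name__)      # Series
print(list(roots.index))         # ['w', 'x', 'y', 'z']

# The same ufunc on the bare array returns an unlabeled ndarray
bare = np.sqrt(ser.to_numpy())
print(type(bare).__name__)       # ndarray
```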
UFuncs: Index Alignment
Ufuncs:索引对齐
For binary operations on two Series or DataFrame objects, Pandas will align indices in the process of performing
the operation. This is very convenient when working with incomplete data, as we'll see in some of the examples that
follow.
对于两个 Series 或 DataFrame 进行二元运算操作,Pandas会在运算过程中自动将两个数据集的索引进行对齐。这在我们处理不完整的数据集时非常方便,下面我们来看一些例子。
Index alignment in Series
Series 对象中的索引对齐
As an example, suppose we are combining two different data sources, and find only the top three US states by area and
the top three US states by population:
假设我们从两个不同的数据源分别获得美国前三⼤⾯积和前三⼤⼈⼝的州,作为下⾯的例⼦:
In [6]: area = pd.Series({'Alaska': 1723337, 'Texas': 695662,
                          'California': 423967}, name='area')
        population = pd.Series({'California': 38332521, 'Texas': 26448193,
                                'New York': 19651127}, name='population')
Let's see what happens when we divide these to compute the population density:
然后我们将⼈⼝和⾯积相除,计算各州的⼈⼝密度:
In [7]: population / area
Out[7]: Alaska              NaN
        California    90.413926
        New York            NaN
        Texas         38.018740
        dtype: float64
The resulting array contains the union of indices of the two input arrays, which could be determined using standard
Python set arithmetic on these indices:
结果数组中的索引包含了两个输⼊数组的并集,你可以通过标准的Python集合运算获得:
In [8]: area.index.union(population.index)  # set union; | on an Index is no longer a union in newer Pandas
Out[8]: Index(['Alaska', 'California', 'New York', 'Texas'], dtype='object')
Any item for which one or the other does not have an entry is marked with NaN , or "Not a Number," which is how
Pandas marks missing data (see further discussion of missing data in Handling Missing Data). This index matching is
implemented this way for any of Python's built-in arithmetic expressions; any missing values are filled in with NaN by
default:
两个任意输⼊数据集中对应的另⼀个数据集不存在的元素都会被设置为 NaN (⾮数字的缩写),也就是Pandas标⽰缺失数据的⽅法(在
处理空缺数据⼀节中会详细讨论)。索引的对⻬⽅式会应⽤在任何Python內建的算术运算上,任何缺失的值都会被填充成NaN:
In [9]: A = pd.Series([2, 4, 6], index=[0, 1, 2])
        B = pd.Series([1, 3, 5], index=[1, 2, 3])
        A + B
Out[9]: 0    NaN
        1    5.0
        2    9.0
        3    NaN
        dtype: float64
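One way to picture what A + B does is to perform the alignment by hand: build the union of the two indices, reindex both operands onto it (labels an operand lacks become NaN), and add the raw arrays. This is a conceptual sketch, not Pandas' actual internal code path.

```python
import pandas as pd

A = pd.Series([2, 4, 6], index=[0, 1, 2])
B = pd.Series([1, 3, 5], index=[1, 2, 3])

# Step 1: the result index is the union of the two indices
idx = A.index.union(B.index)

# Step 2: reindex each operand onto it; labels it lacks become NaN
A_aligned = A.reindex(idx)
B_aligned = B.reindex(idx)

# Step 3: add the raw arrays position by position and reattach the index
manual = pd.Series(A_aligned.to_numpy() + B_aligned.to_numpy(), index=idx)

print(manual.equals(A + B))   # True
```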
If using NaN values is not the desired behavior, the fill value can be modified using appropriate object methods in place of
the operators. For example, calling A.add(B) is equivalent to calling A + B , but allows optional explicit specification
of the fill value for any elements in A or B that might be missing:
如果填充成NaN值不是你需要的结果,你可以使用相应的对象方法来代替运算符,并在方法中设置填充值。例如,调用 A.add(B) 等同于调用 A + B ,但是可以额外指定一个填充值,用来替代 A 或 B 中可能缺失的元素:
In [10]: A.add(B, fill_value=0)
Out[10]: 0    2.0
         1    5.0
         2    9.0
         3    5.0
         dtype: float64
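One detail worth knowing about fill_value, per the Pandas documentation for Series.add: it fills a location where only one side is missing, but if both sides are missing at a label the result stays NaN. A sketch with hypothetical labels:

```python
import numpy as np
import pandas as pd

A = pd.Series([1.0, np.nan, 3.0], index=['a', 'b', 'c'])
B = pd.Series([10.0, 20.0], index=['a', 'd'])

result = A.add(B, fill_value=0)
print(result)
# a: 1 + 10                     -> 11.0
# b: NaN in A and absent from B -> both sides missing, stays NaN
# c: 3 + filled 0               -> 3.0
# d: filled 0 + 20              -> 20.0
```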
Index alignment in DataFrame
DataFrame 中的索引对齐
A similar type of alignment takes place for both columns and indices when performing operations on DataFrame s:
类似的对⻬⽅式在对 DataFrame 操作当中会同时发⽣在列和⾏上:
In [11]: A = pd.DataFrame(rng.randint(0, 20, (2, 2)),
                          columns=list('AB'))
         A
Out[11]:
   A   B
0  1  11
1  5   1
In [13]: B = pd.DataFrame(rng.randint(0, 10, (3, 3)),
                          columns=list('BAC'))
         B
Out[13]:
   B  A  C
0  3  8  2
1  4  2  6
2  4  8  6
In [14]: A + B
Out[14]:
     A     B   C
0  9.0  14.0 NaN
1  7.0   5.0 NaN
2  NaN   NaN NaN
Notice that indices are aligned correctly irrespective of their order in the two objects, and indices in the result are sorted.
As was the case with Series , we can use the associated object's arithmetic method and pass any desired
fill_value to be used in place of missing entries. Here we'll fill with the mean of all values in A (computed by first
stacking the rows of A ):
注意,无论索引在两个对象中的顺序如何,都会被正确地对齐,并且结果中的索引是经过排序的。与 Series 的情况一样,我们可以使用相应的对象方法来代替运算符,并传入所需的 fill_value 参数来替代缺失值。这里我们会使用 A 中所有值的平均值来填充,这个平均值通过首先堆叠(stack) A 的所有行来计算:
In [15]: fill = A.stack().mean()
         A.add(B, fill_value=fill)
Out[15]:
      A     B     C
0   9.0  14.0   6.5
1   7.0   5.0  10.5
2  12.5   8.5  10.5
The following table lists Python operators and their equivalent Pandas object methods:
下⾯列出了Python的运算操作及其对应的Pandas⽅法:
Python 运算符 (Python operator) | Pandas 方法 (Pandas method(s))
+ | add()
- | sub(), subtract()
* | mul(), multiply()
/ | truediv(), div(), divide()
// | floordiv()
% | mod()
** | pow()
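The table can be verified mechanically. This sketch, assuming a recent Pandas, applies each operator (through Python's operator module) and the corresponding method to the same misaligned pair of Series and checks that the results match, including the long-form aliases.

```python
import operator
import pandas as pd

a = pd.Series([7, 8, 9], index=['x', 'y', 'z'])
b = pd.Series([2, 3, 4], index=['y', 'z', 'w'])

pairs = [(operator.add, 'add'), (operator.sub, 'sub'),
         (operator.mul, 'mul'), (operator.truediv, 'truediv'),
         (operator.floordiv, 'floordiv'), (operator.mod, 'mod'),
         (operator.pow, 'pow')]

for op, name in pairs:
    via_operator = op(a, b)              # e.g. a + b
    via_method = getattr(a, name)(b)     # e.g. a.add(b)
    assert via_operator.equals(via_method), name

# The long-form names are aliases of the short ones
assert a.subtract(b).equals(a.sub(b))
assert a.multiply(b).equals(a.mul(b))
assert a.div(b).equals(a.truediv(b))
assert a.divide(b).equals(a.truediv(b))
print("all operator/method pairs agree")
```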
Ufuncs: Operations Between DataFrame and Series
Ufuncs:DataFrame和Series之间的操作
When performing operations between a DataFrame and a Series , the index and column alignment is similarly
maintained. Operations between a DataFrame and a Series are similar to operations between a two-dimensional
and one-dimensional NumPy array. Consider one common operation, where we find the difference of a two-dimensional
array and one of its rows:
当在 DataFrame 和 Series 之间进⾏运算操作时,⾏和列的标签对⻬机制依然有效。 DataFrame 和 Series 之间的操作类似于在⼀
维数组和⼆维数组之间进⾏操作。例如⼀个很常⻅的操作,我们想要找出⼀个⼆维数组和它其中⼀⾏的差:
In [16]: A = rng.randint(10, size=(3, 4))
         A
Out[16]: array([[1, 3, 8, 1],
                [9, 8, 9, 4],
                [1, 3, 6, 7]])
In [17]: A - A[0]
Out[17]: array([[ 0,  0,  0,  0],
                [ 8,  5,  1,  3],
                [ 0,  0, -2,  6]])
According to NumPy's broadcasting rules (see Computation on Arrays: Broadcasting), subtraction between a two-dimensional array and one of its rows is applied row-wise.
依据NumPy的⼴播规则(参⻅在数组上计算:⼴播),⼆维数组的每⼀⾏都会减去它⾃⾝的第⼀⾏。
In Pandas, the convention similarly operates row-wise by default:
在 Pandas 中,默认也是采用这种按行进行的广播机制:
In [18]: df = pd.DataFrame(A, columns=list('QRST'))
         df - df.iloc[0]
Out[18]:
   Q  R  S  T
0  0  0  0  0
1  8  5  1  3
2  0  0 -2  6
If you would instead like to operate column-wise, you can use the object methods mentioned earlier, while specifying the
axis keyword:
如果你希望按列进行减法,可以使用前面提到的对象方法,并指定 axis 参数:
In [19]: df.subtract(df['R'], axis=0)
Out[19]:
   Q  R  S  T
0 -2  0  5 -2
1  1  0  1 -4
2 -2  0  3  4
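A common use of axis=0 is centering each row. In this sketch (assuming a recent Pandas, and reusing the values from In [16]), the default alignment matches a Series against the column labels, so subtracting column means needs no keyword, while subtracting row means needs axis=0; without it, the row labels 0, 1, 2 would be aligned against Q to T and produce only NaN.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame([[1, 3, 8, 1],
                   [9, 8, 9, 4],
                   [1, 3, 6, 7]], columns=list('QRST'))

# Default: the Series index (Q..T) is matched against the COLUMNS
col_centered = df - df.mean()                    # subtract each column's mean

# axis=0: the Series index (0, 1, 2) is matched against the ROW index
row_centered = df.sub(df.mean(axis=1), axis=0)   # subtract each row's mean

print(row_centered)
```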
Note that these DataFrame / Series operations, like the operations discussed above, will automatically align indices
between the two elements:
上⾯介绍的这些 DataFrame 或者 Series 操作,都会⾃动对运算的数据集进⾏索引对⻬:
In [20]: halfrow = df.iloc[0, ::2]  # 第一行的Q和S列
         halfrow
Out[20]: Q    1
         S    8
         Name: 0, dtype: int64
In [21]: df - halfrow
Out[21]:
     Q   R    S   T
0  0.0 NaN  0.0 NaN
1  8.0 NaN  1.0 NaN
2  0.0 NaN -2.0 NaN
This preservation and alignment of indices and columns means that operations on data in Pandas will always maintain
the data context, which prevents the types of silly errors that might come up when working with heterogeneous and/or
misaligned data in raw NumPy arrays.
本节介绍的⾏与列索引保留和对⻬机制说明Pandas在进⾏数据操作时会保持数据的上下⽂信息,因此可以避免同样情况下,使⽤NumPy
数组操作不同形状和异构数据时会发⽣的错误。
Handling Missing Data
处理缺失数据
The difference between data found in many tutorials and data in the real world is that real-world data is rarely clean and
homogeneous. In particular, many interesting datasets will have some amount of data missing. To make matters even
more complicated, different data sources may indicate missing data in different ways.
我们在许多教程⾥⾯看到的数据和真实的数据的区别就是真实的数据很少是⼲净和同质的。更寻常的情况是,很多有意思的数据集都有很
多的数据缺失。更复杂的是,不同的数据源可能有着不同指代缺失数据的⽅式。
In this section, we will discuss some general considerations for missing data, discuss how Pandas chooses to represent
it, and demonstrate some built-in Pandas tools for handling missing data in Python. Here and throughout the book, we'll
refer to missing data in general as null, NaN, or NA values.
在本节中,我们会讨论⼀些对于缺失数据的通⽤处理⽅式,介绍Pandas如何选择和表⽰这些数据,展⽰Pandas中⽤来处理缺失数据的內
建⼯具。本节和本书其他部分,我们会将这些缺失数据标⽰为null、NaN或NA。
Trade-Offs in Missing Data Conventions
缺失数据约定的权衡
There are a number of schemes that have been developed to indicate the presence of missing data in a table or
DataFrame. Generally, they revolve around one of two strategies: using a mask that globally indicates missing values, or
choosing a sentinel value that indicates a missing entry.
⽤来在数据表或DataFrame中指定和标⽰缺失数据的⽅案有很多种。通常来说,会有两种主要的策略:使⽤⼀个全局的遮盖来标⽰缺失数
据,或者选择使⽤哨兵值来标⽰缺失的元素。
In the masking approach, the mask might be an entirely separate Boolean array, or it may involve appropriation of one bit
in the data representation to locally indicate the null status of a value.
在遮盖⽅案中,遮盖层可以是⼀整个独⽴的布尔数组,⼜或者可以在数据中使⽤⼀个⽐特标⽰空值。
In the sentinel approach, the sentinel value could be some data-specific convention, such as indicating a missing integer
value with -9999 or some rare bit pattern, or it could be a more global convention, such as indicating a missing floating-point value with NaN (Not a Number), a special value which is part of the IEEE floating-point specification.
在哨兵值的情况下,哨兵值是某种数据特定的约定值,例如⽤-9999标⽰⼀个缺失的整数或者其他罕⻅的数值,⼜或者使⽤更加通⽤的⽅
式,⽐⽅说标⽰⼀个缺失的浮点数为NaN(⾮数字),NaN是IEEE浮点数标准中的⼀部分。
None of these approaches is without trade-offs: use of a separate mask array requires allocation of an additional Boolean
array, which adds overhead in both storage and computation. A sentinel value reduces the range of valid values that can
be represented, and may require extra (often non-optimized) logic in CPU and GPU arithmetic. Common special values
like NaN are not available for all data types.
以上解决⽅案都是有所取舍的:独⽴的遮盖数组需要更多的内存空间⽤于存储布尔数组;普通的哨兵值会缩⼩正确数据的取值范围,⽽且
需要额外的(通常是未优化的)CPU和GPU运算;通⽤的特殊值如NaN⼜⽆法应⽤于所有的数据类型上。
As in most cases where no universally optimal choice exists, different languages and systems use different conventions.
For example, the R language uses reserved bit patterns within each data type as sentinel values indicating missing data,
while the SciDB system uses an extra byte attached to every cell which indicates a NA state.
因为在大多数情况下并不存在普适的最优策略,所以不同的语言和系统会使用不同的约定。例如,R语言在每种数据类型中保留特定的比特模式作为哨兵值来标示缺失数据,而SciDB系统则在每个元素上附加一个额外的字节,用于标示NA状态。
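A toy illustration of the two strategies with NumPy, assuming a hypothetical sentinel of -9999. The sentinel lives inside the data, so every computation must remember to exclude it; the mask strategy (here NumPy's own numpy.ma masked arrays) carries a separate Boolean array, costing one extra byte per element.

```python
import numpy as np

SENTINEL = -9999                      # hypothetical data-specific convention
raw = np.array([3, SENTINEL, 7, 5])

# Sentinel strategy: the marker is stored in the data itself
valid = raw != SENTINEL
mean_sentinel = raw[valid].mean()     # must filter explicitly
naive_mean = raw.mean()               # forgetting to filter gives nonsense

# Mask strategy: a separate Boolean array travels with the data
masked = np.ma.masked_array(raw, mask=~valid)
mean_masked = masked.mean()           # masked entries are skipped automatically

print(mean_sentinel, mean_masked)     # 5.0 5.0
print(naive_mean)                     # -2496.0
print(masked.mask.nbytes)             # 4: one extra byte per element
```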
Missing Data in Pandas
Pandas中的缺失值
The way in which Pandas handles missing values is constrained by its reliance on the NumPy package, which does not
have a built-in notion of NA values for non-floating-point data types.
Pandas 中处理缺失值的方式受限于它对NumPy包的依赖,而NumPy对于非浮点数类型并没有內建的NA值概念。
Pandas could have followed R's lead in specifying bit patterns for each individual data type to indicate nullness, but this
approach turns out to be rather unwieldy. While R contains four basic data types, NumPy supports far more than this: for
example, while R has a single integer type, NumPy supports fourteen basic integer types once you account for available
precisions, signedness, and endianness of the encoding. Reserving a specific bit pattern in all available NumPy types
would lead to an unwieldy amount of overhead in special-casing various operations for various types, likely even
requiring a new fork of the NumPy package. Further, for the smaller data types (such as 8-bit integers), sacrificing a bit to
use as a mask will significantly reduce the range of values it can represent.
Pandas 本可以采用R语言的方式,即为每种数据类型指定特定的比特模式来标示缺失值,但是这种方案实现起来显得很笨重。R只有4种基本数据类型,而NumPy支持的类型却远超这个数:例如,R只有1种整数类型,而算上精度、是否带符号以及编码的大小端之后,NumPy支持14种基本整数类型。在所有NumPy类型中都保留特定的比特模式作为缺失值标志,会给各种类型的各种操作带来大量的特殊处理开销,甚至很可能需要一个新的NumPy分支。并且在数据类型比较小的情况下(例如8比特的整数),牺牲一个比特作为遮盖会严重缩小该类型可以表达的数值范围。
NumPy does have support for masked arrays – that is, arrays that have a separate Boolean mask array attached for
marking data as "good" or "bad." Pandas could have derived from this, but the overhead in both storage, computation,
and code maintenance makes that an unattractive choice.
NumPy 确实支持遮盖数组,即附带一个独立的布尔遮盖数组、用来标记数据是“好的”还是“坏的”的数组。Pandas本可以在此基础上实现,但是存储、计算和代码维护方面的额外开销使得这种方案不是特别吸引人。
With these constraints in mind, Pandas chose to use sentinels for missing data, and further chose to use two already-existing Python null values: the special floating-point NaN value, and the Python None object. This choice has some
side effects, as we will see, but in practice ends up being a good compromise in most cases of interest.
因此,Pandas选择了最后⼀种⽅案,即通⽤哨兵值标⽰缺失值。更进⼀步说就是,使⽤两个已经存在的Python空值: NaN 代表特殊的浮
点数值和Python的 None 对象。这种做法当然也有⼀些副作⽤,我们后⾯也会看到,但是在实践中它被证明在⼤多数情况下都是⼀个较好
的折中⽅案。
None : Pythonic missing data
None:Python的缺失值
The first sentinel value used by Pandas is None , a Python singleton object that is often used for missing data in Python
code. Because it is a Python object, None cannot be used in any arbitrary NumPy/Pandas array, but only in arrays with
data type 'object' (i.e., arrays of Python objects):
第⼀个被Pandas使⽤的缺失哨兵值是 None ,它是⼀个Python的单例对象,很多情况下它都作为Python代码中缺失值的标志。因为这是
⼀个Python对象, None 不能在任意的NumPy或Pandas数组中使⽤,它只能在数组的数据类型是 object 的情况下使⽤(例如,Python
对象组成的数组):
In [1]: import numpy as np
import pandas as pd
In [2]: vals1 = np.array([1, None, 3, 4])
vals1
Out[2]: array([1, None, 3, 4], dtype=object)
This dtype=object means that the best common type representation NumPy could infer for the contents of the array
is that they are Python objects. While this kind of object array is useful for some purposes, any operations on the data will
be done at the Python level, with much more overhead than the typically fast operations seen for arrays with native types:
这⾥的 dtype=object 表⽰这个NumPy数组的元素类型是Python的对象。虽然这种类型的对象数组在某些场景中很有⽤,任何数据的操
作都会在Python层⾯进⾏,这会⽐NumPy其他基础类型进⾏的快速操作消耗更多的执⾏时间:
In [3]: for dtype in ['object', 'int']:
            print("dtype =", dtype)
            %timeit np.arange(1E6, dtype=dtype).sum()
            print()
dtype = object
106 ms ± 645 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
dtype = int
2.55 ms ± 8.77 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
The use of Python objects in an array also means that if you perform aggregations like sum() or min() across an
array with a None value, you will generally get an error:
⽽且使⽤Python对象作为数组数据类型的话,当使⽤聚合操作如 sum() 或 min() 的时候,如果碰到了 None 值,那就会产⽣⼀个错
误:
In [4]: vals1.sum()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-30a3fc8c6726> in <module>
----> 1 vals1.sum()
~/anaconda3/lib/python3.7/site-packages/numpy/core/_methods.py in _sum(a, axis, dtype, out, keepdims, initial)
     34 def _sum(a, axis=None, dtype=None, out=None, keepdims=False,
     35          initial=_NoValue):
---> 36     return umr_sum(a, axis, dtype, out, keepdims, initial)
     37
     38 def _prod(a, axis=None, dtype=None, out=None, keepdims=False,
TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'
This reflects the fact that addition between an integer and None is undefined.
错误的原因是整数和 None 对象之间进⾏加法运算是未定义的。
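A sketch of the failure and one simple workaround, assuming NumPy: catch the TypeError, or drop the None entries before aggregating. (The Pandas tools for this, such as dropna(), are discussed later in this section.)

```python
import numpy as np

vals1 = np.array([1, None, 3, 4])    # dtype=object, as in In [2]

raised = False
try:
    vals1.sum()                      # 1 + None is undefined
except TypeError as exc:
    raised = True
    print("TypeError:", exc)

# Workaround: keep only the non-None entries, then aggregate
cleaned = np.array([v for v in vals1 if v is not None])
print(cleaned.sum())                 # 8
```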
NaN : Missing numerical data
NaN:缺失的数值类型数据
The other missing data representation, NaN (acronym for Not a Number), is different; it is a special floating-point value
recognized by all systems that use the standard IEEE floating-point representation:
另外一种缺失数据的表示形式 NaN (非数字的缩写)则有所不同:它是一个特殊的浮点数值,能被所有使用IEEE浮点数标准的系统所识别:
In [5]: vals2 = np.array([1, np.nan, 3, 4])
vals2.dtype
Out[5]: dtype('float64')
Notice that NumPy chose a native floating-point type for this array: this means that unlike the object array from before,
this array supports fast operations pushed into compiled code. You should be aware that NaN is a bit like a data virus–it
infects any other object it touches. Regardless of the operation, the result of arithmetic with NaN will be another NaN :
NumPy 为这个数组选择了原生的浮点类型:这意味着不像前面的对象数组,这个数组支持使用编译代码来进行快速运算。你应该了解到 NaN 就像一个数据的病毒,它会传染到任何接触到的数据。不论运算是哪种类型, NaN 参与的算术运算的结果都会是另一个 NaN :
In [6]: 1 + np.nan
Out[6]: nan
In [7]: 0 * np.nan
Out[7]: nan
Note that this means that aggregates over the values are well defined (i.e., they don't result in an error) but not always
useful:
因此对于这个数组进⾏的聚合操作是良好定义的(意思是不会发⽣错误),但是却并不⼗分有意义:
In [8]: vals2.sum(), vals2.min(), vals2.max()
Out[8]: (nan, nan, nan)
NumPy does provide some special aggregations that will ignore these missing values:
NumPy 还提供了一些特殊的聚合函数,可以用来忽略这些缺失值:
In [9]: np.nansum(vals2), np.nanmin(vals2), np.nanmax(vals2)
Out[9]: (8.0, 1.0, 4.0)
Keep in mind that NaN is specifically a floating-point value; there is no equivalent NaN value for integers, strings, or
other types.
请记住 NaN 是⼀个特殊的浮点数值;对于整数、字符串或者其他类型来说都没有对应的值。
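One more property with practical consequences: under the IEEE rules NaN compares unequal to everything, itself included, so == cannot be used to find it. A small NumPy sketch:

```python
import numpy as np

vals2 = np.array([1, np.nan, 3, 4])

# NaN is not equal to anything, including itself, so == cannot find it
print(np.nan == np.nan)          # False
print((vals2 == np.nan).any())   # False

# np.isnan is the reliable test
print(np.isnan(vals2))           # [False  True False False]

# NaN-aware aggregates skip the missing entry
print(np.nansum(vals2), np.nanmean(vals2))   # 8.0 2.6666666666666665
```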
NaN and None in Pandas
Pandas中的NaN和None
NaN and None both have their place, and Pandas is built to handle the two of them nearly interchangeably, converting
between them where appropriate:
NaN 和 None 在Pandas中都有各自的用武之地,而且Pandas基本上将两者等同处理,可以在合适的情况下互相转换:
In [10]: pd.Series([1, np.nan, 2, None])
Out[10]: 0    1.0
         1    NaN
         2    2.0
         3    NaN
         dtype: float64
For types that don't have an available sentinel value, Pandas automatically type-casts when NA values are present. For
example, if we set a value in an integer array to np.nan , it will automatically be upcast to a floating-point type to
accommodate the NA:
对于那些没有通用哨兵值的类型,Pandas在发现出现了NA值的情况下会自动对它们进行类型转换。例如,如果我们在一个整数数组中设
置了一个 np.nan 值,整个数组会自动向上扩展为浮点类型:
In [11]: x = pd.Series(range(2), dtype=int)
x
Out[11]: 0    0
         1    1
         dtype: int64
In [12]: x[0] = None
x
Out[12]: 0    NaN
         1    1.0
         dtype: float64
Notice that in addition to casting the integer array to floating point, Pandas automatically converts the None to a NaN
value. (When this text was originally written, adding a native integer NA to Pandas was only a proposal; Pandas 0.24 and
later provide opt-in nullable integer dtypes such as "Int64", which in current versions mark missing entries with pd.NA.)
上述例子中除了将整数类型转换为浮点数类型之外,Pandas还自动将 None 转换成了 NaN 值。(原书写作时,在Pandas中加入原生的整数
NA值还只是一个提议;Pandas 0.24及之后的版本提供了可选的可空整数类型,例如 "Int64" ,当前版本中使用 pd.NA 标记缺失值。)
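For comparison, here is a minimal sketch of the nullable integer dtype mentioned above. It assumes Pandas 1.0 or later, where missing entries in this dtype are represented by `pd.NA`. The dtype name `"Int64"`, with a capital I, is what distinguishes it from the default `int64`.

```python
import pandas as pd

# Default dtypes: a missing value forces the integers up to float64.
default = pd.Series([1, None, 3])

# Opt-in nullable integer dtype (note the capital "I"): integers stay integers.
nullable = pd.Series([1, None, 3], dtype="Int64")

print(default.dtype)             # float64
print(nullable.dtype)            # Int64
print(nullable.isna().tolist())  # [False, True, False]
print(nullable[1] is pd.NA)      # True
print(nullable.sum())            # 4 (the missing entry is skipped)
```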
While this type of magic may feel a bit hackish compared to the more unified approach to NA values in domain-specific
languages like R, the Pandas sentinel/casting approach works quite well in practice and in my experience only rarely
causes issues.
虽然这种解决⽅案对⽐起类似R语⾔那样使⽤统⼀的NA值来标⽰的⽅案来说,显得有点像魔术。但是Pandas的这种哨兵+类型转换的⽅式
在实践中运⾏良好,⽽且在作者的经验中,很少导致问题。
The following table lists the upcasting conventions in Pandas when NA values are introduced:
下表列出了Pandas在出现NA值的时候向上类型扩展的规则:
Typeclass    Conversion when storing NAs    NA sentinel value
floating     No change                      np.nan
object       No change                      None or np.nan
integer      Cast to float64                np.nan
boolean      Cast to object                 None or np.nan
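Keep in mind that in the Pandas versions this text targets, string data is stored with an object dtype; newer versions also
provide a dedicated "string" dtype.
在本文所针对的Pandas版本中,字符串数据使用 object 类型存储;较新的版本还提供了专门的 "string" 类型。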
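One way to watch these conventions in action is to introduce a missing value by reindexing with a label that does not exist yet. This is a sketch. The dictionary of example Series is illustrative, and only the floating, integer and boolean rows of the table are exercised.

```python
import pandas as pd

examples = {
    "floating": pd.Series([1.5, 2.5]),
    "integer": pd.Series([1, 2]),
    "boolean": pd.Series([True, False]),
}

results = {}
for kind, s in examples.items():
    # Label 2 does not exist in the original index, so it becomes NA.
    with_na = s.reindex([0, 1, 2])
    results[kind] = str(with_na.dtype)
    print(f"{kind:9s} {str(s.dtype):8s} -> {with_na.dtype}")
```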
Operating on Null Values
操作空值
As we have seen, Pandas treats None and NaN as essentially interchangeable for indicating missing or null values. To
facilitate this convention, there are several useful methods for detecting, removing, and replacing null values in Pandas
data structures. They are:
isnull() : Generate a boolean mask indicating missing values
notnull() : Opposite of isnull()
dropna() : Return a filtered version of the data
fillna() : Return a copy of the data with missing values filled or imputed
我们已经看到,Pandas将 None 和 NaN 看成是可以互相转换的缺失值或空值。与此同时,Pandas还提供了⼀些很有⽤的⽅法⽤来在数据
集中发现、移除和替换空值。这些⽅法包括:
isnull() :⽣成⼀个布尔遮盖数组指⽰缺失值的位置
notnull() : isnull() 相反⽅法
dropna() :返回⼀个过滤掉缺失值、空值的数据集
fillna() :返回⼀个数据集的副本,⾥⾯的缺失值、空值使⽤另外的值来替代
We will conclude this section with a brief exploration and demonstration of these routines.
我们在最后讨论这些⽅法作为本节的总结。
Detecting null values
检测空值
Pandas data structures have two useful methods for detecting null data: isnull() and notnull() . Either one will
return a Boolean mask over the data. For example:
Pandas数据集有两个方法用来检测空值: isnull() 和 notnull() 。它们都会返回一个布尔遮盖数组。例如:
In [13]: data = pd.Series([1, np.nan, 'hello', None])
In [14]: data.isnull()
Out[14]: 0    False
         1     True
         2    False
         3     True
         dtype: bool
As mentioned in Data Indexing and Selection, Boolean masks can be used directly as a Series or DataFrame
index:
在数据索引和选择中我们已经介绍过,布尔遮盖数组可以直接在 Series 或 DataFrame 对象上作为索引使⽤:
In [15]: data[data.notnull()]
Out[15]: 0        1
         2    hello
         dtype: object
The isnull() and notnull() methods produce similar Boolean results for DataFrame objects.
在 DataFrame 对象上, isnull() 和 notnull() ⽅法也会产⽣相似的布尔数组。
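A small sketch of what the Boolean mask looks like for a DataFrame, and two common ways to summarize it. It counts missing values per column and flags rows with no missing values. The DataFrame is the same one used in the next subsection.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame([[1, np.nan, 2],
                   [2, 3, 5],
                   [np.nan, 4, 6]])

mask = df.isnull()                        # Boolean DataFrame, same shape as df
per_column = mask.sum()                   # missing values in each column
total_missing = int(mask.to_numpy().sum())
complete_rows = df.notnull().all(axis=1)  # True where a row has no missing values

print(mask)
print(per_column.tolist())                # [1, 1, 0]
print(total_missing)                      # 2
print(complete_rows.tolist())             # [False, True, False]
```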
Dropping null values
去除空值
In addition to the masking used before, there are the convenience methods, dropna() (which removes NA values) and
fillna() (which fills in NA values). For a Series , the result is straightforward:
除了上⾯的遮盖之外,还有两个很⽅便的⽅法 dropna() (移除NA值)和 fillna() (填充NA值)。对于 Series 对象来说,结果显
⽽易⻅:
In [16]: data.dropna()
Out[16]: 0        1
         2    hello
         dtype: object
For a DataFrame , there are more options. Consider the following DataFrame :
对于 DataFrame 对象,提供了更多选项。考虑下⾯的 DataFrame :
In [17]: df = pd.DataFrame([[1,      np.nan, 2],
                            [2,      3,      5],
                            [np.nan, 4,      6]])
         df
Out[17]:      0    1  2
         0  1.0  NaN  2
         1  2.0  3.0  5
         2  NaN  4.0  6
We cannot drop single values from a DataFrame ; we can only drop full rows or full columns. Depending on the
application, you might want one or the other, so dropna() gives a number of options for a DataFrame .
我们不能在 DataFrame 中移除单个空值;我们只能移除整⾏或者整列。取决于需求,你可能想移除⾏或列之⼀, dropna() 为
DataFrame 对象提供了⼀些参数选择。
By default, dropna() will drop all rows in which any null value is present:
默认, dropna() 会移除出现了空值的整⾏:
In [18]: df.dropna()
Out[18]:      0    1  2
         1  2.0  3.0  5
Alternatively, you can drop NA values along a different axis; axis=1 drops all columns containing a null value:
你可以通过设置axis参数(如 axis=1 )来沿着不同的维度来移除空值,下⾯是移除含有空值的列的例⼦:
In [19]: df.dropna(axis='columns')
Out[19]:    2
         0  2
         1  5
         2  6
But this drops some good data as well; you might rather be interested in dropping rows or columns with all NA values, or
a majority of NA values. This can be specified through the how or thresh parameters, which allow fine control of the
number of nulls to allow through.
但是这会移除⼀些良好的数据;你可能更希望移除那些全部是NA值或者⼤部分是NA值的⾏或列。这可以通过设置 how 或 thresh 参数来
实现,它们可以更加精细地控制移除的⾏或列包含的空值个数。
The default is how='any' , such that any row or column (depending on the axis keyword) containing a null value will
be dropped. You can also specify how='all' , which will only drop rows/columns that are all null values:
默认的情况是 how='any' ,因此任何行或列只要含有空值都会被移除。你可以将它设置为 how='all' ,这样只有那些行或列全部由空值
构成的情况下才会被移除:
In [20]: df[3] = np.nan
df
Out[20]:      0    1  2   3
         0  1.0  NaN  2 NaN
         1  2.0  3.0  5 NaN
         2  NaN  4.0  6 NaN
In [21]: df.dropna(axis='columns', how='all')
Out[21]:      0    1  2
         0  1.0  NaN  2
         1  2.0  3.0  5
         2  NaN  4.0  6
For finer-grained control, the thresh parameter lets you specify a minimum number of non-null values for the
row/column to be kept:
如果需要更加精细的控制, thresh 参数可以让你指定结果中每⾏或列⾄少包含⾮空值的个数:
In [22]: df.dropna(axis='rows', thresh=3)  # keep only rows with at least 3 non-null values
Out[22]:      0    1  2   3
         1  2.0  3.0  5 NaN
Here the first and last row have been dropped, because they contain only two non-null values.
上例中第⼀⾏和第三⾏被移除了,因为它们都只含有2个⾮空值。
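To see why `thresh=3` keeps only the middle row, count the non-null values per row. The sketch below also shows the `subset` parameter of `dropna()`, which the text does not cover: it limits which columns are inspected for nulls. The DataFrame rebuilds the `df` used above, including the all-NaN column 3.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame([[1, np.nan, 2],
                   [2, 3, 5],
                   [np.nan, 4, 6]])
df[3] = np.nan

# Non-null count per row: thresh keeps rows whose count is >= the threshold.
non_null_counts = df.notnull().sum(axis=1).tolist()   # [2, 3, 2]
kept_by_thresh = df.dropna(thresh=3).index.tolist()    # [1]

# subset: only look for nulls in column 0 when deciding which rows to drop.
kept_by_subset = df.dropna(subset=[0]).index.tolist()  # [0, 1]

print(non_null_counts, kept_by_thresh, kept_by_subset)
```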
Filling null values
填充空值
Sometimes rather than dropping NA values, you'd rather replace them with a valid value. This value might be a single
number like zero, or it might be some sort of imputation or interpolation from the good values. You could do this in-place
using the isnull() method as a mask, but because it is such a common operation Pandas provides the fillna()
method, which returns a copy of the array with the null values replaced.
有时我们想要的不是移除NA值,⽽是希望将它们替换为正确的值。替换后的值可能是⼀个标量如0,或者从其他正确数值归并或插补的
值。你当然可以使⽤ isnull() 然后赋值的⽅式来实现,但是因为这个需求是如此⼴泛,Pandas提供了 fillna() ⽅法,⽤来返回⼀个
替换空值后的数据集副本。
Consider the following Series :
考虑下⾯的 Series :
In [23]: data = pd.Series([1, np.nan, 2, None, 3], index=list('abcde'))
data
Out[23]: a    1.0
         b    NaN
         c    2.0
         d    NaN
         e    3.0
         dtype: float64
We can fill NA entries with a single value, such as zero:
我们可以将NA值替换成为⼀个标量,例如0:
In [24]: data.fillna(0)
Out[24]: a    1.0
         b    0.0
         c    2.0
         d    0.0
         e    3.0
         dtype: float64
We can specify a forward-fill to propagate the previous value forward:
我们也可以指定填充的⽅法,如向前填充,将前⼀个值传播到下⼀个空值:
In [25]: # forward-fill (equivalent to the older spelling data.fillna(method='ffill'))
data.ffill()
Out[25]: a    1.0
         b    1.0
         c    2.0
         d    2.0
         e    3.0
         dtype: float64
Or we can specify a back-fill to propagate the next values backward:
或者使⽤向后填充,使⽤后⼀个有效值传播到前⼀个空值:
In [26]: # back-fill (equivalent to the older spelling data.fillna(method='bfill'))
data.bfill()
Out[26]: a    1.0
         b    2.0
         c    2.0
         d    3.0
         e    3.0
         dtype: float64
For DataFrame s, the options are similar, but we can also specify an axis along which the fills take place:
对于 DataFrame 对象,选项是类似的,但是我们可以指定 axis 参数让填充沿着某个特定维度进⾏:
In [27]: df
Out[27]:      0    1  2   3
         0  1.0  NaN  2 NaN
         1  2.0  3.0  5 NaN
         2  NaN  4.0  6 NaN
In [28]: # forward-fill along the columns (older spelling: df.fillna(method='ffill', axis=1))
df.ffill(axis=1)
Out[28]:      0    1    2    3
         0  1.0  1.0  2.0  2.0
         1  2.0  3.0  5.0  5.0
         2  NaN  4.0  6.0  6.0
Notice that if a previous value is not available during a forward fill, the NA value remains.
结果看到如果空值的前⾯没有值(此处的 df.loc[2, 0] 前⾯已经没有列,沿着列填充),那么NA值将会保留下来。
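If leftover leading NAs are a problem, one common option is to chain a backward fill after the forward fill. The sketch below applies this to the same DataFrame. Whether borrowing the next value is acceptable depends on the data, so treat it as one choice among several.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame([[1, np.nan, 2],
                   [2, 3, 5],
                   [np.nan, 4, 6]])
df[3] = np.nan

forward = df.ffill(axis=1)
# Row 2 has no value to the left of column 0, so that NaN survives.
still_missing = int(forward.isna().sum().sum())   # 1

# A follow-up backward fill closes the remaining leading gap.
both = df.ffill(axis=1).bfill(axis=1)

print(forward)
print(both)
```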
Hierarchical Indexing
层次化索引
Up to this point we've been focused primarily on one-dimensional and two-dimensional data, stored in Pandas Series
and DataFrame objects, respectively. Often it is useful to go beyond this and store higher-dimensional data, that is,
data indexed by more than one or two keys. While older versions of Pandas provided Panel and Panel4D objects that natively
handled three-dimensional and four-dimensional data (see Aside: Panel Data; both have since been deprecated and removed), a
far more common pattern in practice is to make use of hierarchical indexing (also known as multi-indexing) to incorporate
multiple index levels within a single index. In this way, higher-dimensional data can be compactly represented within the
familiar one-dimensional Series and two-dimensional DataFrame objects.
直到目前为止,我们主要集中在一维和二维数据上,它们分别被存储在Pandas的 Series 和 DataFrame 对象当中。很多时候,我们需要超
越二维来存储更高维度的数据,即用来检索的关键字会超过1个或2个。虽然Pandas早期版本提供了 Panel 和 Panel4D 对象来原生处理三维
和四维数据(参见额外内容:Panel数据;这两个对象现已被弃用并移除),但是我们在实践中更常用的方式是使用层次化索引(也被称为多重
索引)来将多个索引层次在一个索引中结合起来。使用这种方法,高维数据也可以用紧凑的方式表示成我们熟悉的一维 Series 和二维
DataFrame 对象。
In this section, we'll explore the direct creation of MultiIndex objects, considerations when indexing, slicing, and
computing statistics across multiply indexed data, and useful routines for converting between simple and hierarchically
indexed representations of your data.
在本节中,我们会讨论 多重索引 对象的直接创建⽅式,当我们在多重索引数据中进⾏索引、切⽚和统计的⽅式,还会介绍在简单索引和
多重索引之间进⾏转换的⽅法。
We begin with the standard imports:
⾸先还是先进⾏标准载⼊:
In [1]: import pandas as pd
import numpy as np
A Multiply Indexed Series
多重索引Series
Let's start by considering how we might represent two-dimensional data within a one-dimensional Series . For
concreteness, we will consider a series of data where each point has a character and numerical key.
我们从在⼀维 Series 中表⽰⼆维数据开始。我们考虑⼀个序列的数据,每个数据点都有⼀个字符串和数字关键字。
The bad way
不好的做法
Suppose you would like to track data about states from two different years. Using the Pandas tools we've already
covered, you might be tempted to simply use Python tuples as keys:
设想你想追踪州⼈⼝两个不同年份的数据。使⽤我们已经学过的Pandas⼯具,你可能会想简单的使⽤Python元组来作为key:
In [2]: index = [('California', 2000), ('California', 2010),
('New York', 2000), ('New York', 2010),
('Texas', 2000), ('Texas', 2010)]
populations = [33871648, 37253956,
18976457, 19378102,
20851820, 25145561]
pop = pd.Series(populations, index=index)
pop
Out[2]: (California, 2000)    33871648
        (California, 2010)    37253956
        (New York, 2000)      18976457
        (New York, 2010)      19378102
        (Texas, 2000)         20851820
        (Texas, 2010)         25145561
        dtype: int64
With this indexing scheme, you can straightforwardly index or slice the series based on this multiple index:
使⽤这种索引策略,你可以直接在series中对多个索引进⾏检索或切⽚:
In [3]: pop[('California', 2010):('Texas', 2000)]
Out[3]: (California, 2010)    37253956
        (New York, 2000)      18976457
        (New York, 2010)      19378102
        (Texas, 2000)         20851820
        dtype: int64
But the convenience ends there. For example, if you need to select all values from 2010, you'll need to do some messy
(and potentially slow) munging to make it happen:
但是这种便利性也就到此为⽌了。例如,如果你需要2010年的全部数据,就需要写⼀些没那么直观(且可能低性能的)的代码来实现了:
In [4]: pop[[i for i in pop.index if i[1] == 2010]]
Out[4]: (California, 2010)    37253956
        (New York, 2010)      19378102
        (Texas, 2010)         25145561
        dtype: int64
This produces the desired result, but is not as clean (or as efficient for large datasets) as the slicing syntax we've grown
to love in Pandas.
结果是正确的,但是对⽐起我们已经开始喜爱的Pandas切⽚语法来说,代码并没那么易读(或者在处理⼤数据集时低效)。
The Better Way: Pandas MultiIndex
更好的⽅法:Pandas多重索引
Fortunately, Pandas provides a better way. Our tuple-based indexing is essentially a rudimentary multi-index, and the
Pandas MultiIndex type gives us the type of operations we wish to have. We can create a multi-index from the tuples
as follows:
幸运的是,Pandas提供了⼀个更好的⽅法。刚才的那个元组索引的⽅式是⼀个初级的多重索引,Pandas MultiIndex 类型提供了我们需
要的真正的多重索引功能。我们可以按照下⾯的⽅式从元组创建⼀个多重索引:
In [5]: index = pd.MultiIndex.from_tuples(index)
index
Out[5]: MultiIndex([('California', 2000),
                    ('California', 2010),
                    (  'New York', 2000),
                    (  'New York', 2010),
                    (     'Texas', 2000),
                    (     'Texas', 2010)],
                   )
Notice that the MultiIndex contains multiple levels of indexing–in this case, the state names and the years, as well as
multiple labels for each data point which encode these levels.
注意上⾯的 MultiIndex 对象包含多重层级的索引,本例中为州名和年份,同时也有多个编码标签对应着每个数据点。
If we re-index our series with this MultiIndex , we see the hierarchical representation of the data:
如果我们使⽤这个 MultiIndex 对我们的series进⾏重新索引,我们可以看到这个数据集的层级展⽰:
In [6]: pop = pop.reindex(index)
pop
Out[6]: California  2000    33871648
                    2010    37253956
        New York    2000    18976457
                    2010    19378102
        Texas       2000    20851820
                    2010    25145561
        dtype: int64
Here the first two columns of the Series representation show the multiple index values, while the third column shows
the data. Notice that some entries are missing in the first column: in this multi-index representation, any blank entry
indicates the same value as the line above it.
上表中 Series 的头两列代表着多重索引的值,第三列代表数据值。第⼀列中有些⾏的数据缺失了,在多重索引展⽰中,缺失的索引值数
据表⽰它与上⼀⾏具有相同的值。
Now to access all data for which the second index is 2010, we can simply use the Pandas slicing notation:
现在想要获取第⼆个索引值为2010年的数据,我们只需要简单的使⽤Pandas的切⽚语法即可:
In [7]: pop[:, 2010]
Out[7]: California    37253956
        New York      19378102
        Texas         25145561
        dtype: int64
The result is a singly indexed array with just the keys we're interested in. This syntax is much more convenient (and the
operation is much more efficient!) than the home-spun tuple-based multi-indexing solution that we started with. We'll now
further discuss this sort of indexing operation on hierarchically indexed data.
结果变成了⼀个单⼀索引的数组,且仅带有我们感兴趣的索引。这个语法显然⽐起我们前⾯使⽤元组作为多重索引的⽅案⽅便多了(当然
性能也优异很多)。我们会深⼊讨论在层次化索引数据上进⾏操作的⽅法。
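`pop[:, 2010]` is not the only way to make this selection. The sketch below shows two alternatives, `xs()` and `get_level_values()`, both standard MultiIndex tools that this section does not otherwise use. Note that `xs()` drops the selected level, while a Boolean mask keeps both levels.

```python
import pandas as pd

index = pd.MultiIndex.from_tuples([('California', 2000), ('California', 2010),
                                   ('New York', 2000), ('New York', 2010),
                                   ('Texas', 2000), ('Texas', 2010)])
pop = pd.Series([33871648, 37253956, 18976457, 19378102,
                 20851820, 25145561], index=index)

# Cross-section: select 2010 on level 1 and drop that level from the result.
by_xs = pop.xs(2010, level=1)

# The same selection as a Boolean mask on the values of level 1.
by_mask = pop[pop.index.get_level_values(1) == 2010]

print(by_xs)
print(by_mask)
```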
MultiIndex as extra dimension
多重索引作为额外维度
You might notice something else here: we could easily have stored the same data using a simple DataFrame with
index and column labels. In fact, Pandas is built with this equivalence in mind. The unstack() method will quickly
convert a multiply indexed Series into a conventionally indexed DataFrame :
你可能已经注意到上例中,我们可以很简单的将数据存储在⼀个简单的 DataFrame ⾥⾯,州名作为⾏索引,年份作为列索引。实际上,
Pandas已经内建了这种等同的机制。 unstack() ⽅法可以很快地将多重索引的 Series 转换成普通索引的 DataFrame :
In [8]: pop_df = pop.unstack()
pop_df
Out[8]:                 2000      2010
        California  33871648  37253956
        New York    18976457  19378102
        Texas       20851820  25145561
Naturally, the stack() method provides the opposite operation:
⾃然⽽然的, stack() ⽅法提供了相反的操作:
In [9]: pop_df.stack()
Out[9]: California  2000    33871648
                    2010    37253956
        New York    2000    18976457
                    2010    19378102
        Texas       2000    20851820
                    2010    25145561
        dtype: int64
Seeing this, you might wonder why we would bother with hierarchical indexing at all. The reason is simple: just as
we were able to use multi-indexing to represent two-dimensional data within a one-dimensional Series , we can also
use it to represent data of three or more dimensions in a Series or DataFrame . Each extra level in a multi-index
represents an extra dimension of data; taking advantage of this property gives us much more flexibility in the types of
data we can represent. Concretely, we might want to add another column of demographic data for each state at each
year (say, population under 18); with a MultiIndex this is as easy as adding another column to the DataFrame :
看到这⾥,你可能会疑惑为什么我们需要使⽤层次化索引。原因很简单:就像我们可以使⽤多重索引来将⼀维 Series 表⽰成⼆维数据⼀
样,我们也可以使⽤ Series 或 DataFrame 来表⽰三维或多维的数据。每个多重索引中的额外层次都代表着数据中额外的维度;利⽤这
点我们可以灵活地详细地展⽰我们的数据,例如我们希望在上⾯各州各年⼈⼝数据的基础上增加⼀列(⽐⽅说18岁以下⼈⼝数);使⽤
MultiIndex 能很简单的为 DataFrame 增加⼀列:
In [10]: pop_df = pd.DataFrame({'total': pop,
'under18': [9267089, 9284094,
4687374, 4318033,
5906301, 6879014]})
pop_df
Out[10]:                     total  under18
         California 2000  33871648  9267089
                    2010  37253956  9284094
         New York   2000  18976457  4687374
                    2010  19378102  4318033
         Texas      2000  20851820  5906301
                    2010  25145561  6879014
In addition, all the ufuncs and other functionality discussed in Operating on Data in Pandas work with hierarchical indices
as well. Here we compute the fraction of people under 18 by year, given the above data:
除此之外,所有在Pandas中操作数据一节中介绍过的ufuncs和其他功能也可以应用到层次化索引数据上。下面我们根据上面的数据计算每年
18岁以下人口的比例:
In [11]: f_u18 = pop_df['under18'] / pop_df['total']
f_u18.unstack()
Out[11]:                 2000      2010
         California  0.273594  0.249211
         New York    0.247010  0.222831
         Texas       0.283251  0.273568
This allows us to easily and quickly manipulate and explore even high-dimensional data.
这允许我们能简单和迅速的操作数据,甚⾄是⾼维度的数据。
Methods of MultiIndex Creation
多重索引创建的⽅法
The most straightforward way to construct a multiply indexed Series or DataFrame is to simply pass a list of two or
more index arrays to the constructor. For example:
最直接的构建多重索引 Series 或 DataFrame 的⽅式是向index参数传递⼀个多重列表。例如:
In [12]: df = pd.DataFrame(np.random.rand(4, 2),
index=[['a', 'a', 'b', 'b'], [1, 2, 1, 2]],
columns=['data1', 'data2'])
df
Out[12]:         data1     data2
         a 1  0.024362  0.784210
           2  0.383360  0.278085
         b 1  0.679827  0.063426
           2  0.704108  0.689651
The work of creating the MultiIndex is done in the background.
创建 MultiIndex 的⼯作会⾃动完成。
Similarly, if you pass a dictionary with appropriate tuples as keys, Pandas will automatically recognize this and use a
MultiIndex by default:
类似的,如果你使⽤元组作为关键字的字典数据传给Series,Pandas也会⾃动识别并默认使⽤ MultiIndex :
In [13]: data = {('California', 2000): 33871648,
('California', 2010): 37253956,
('Texas', 2000): 20851820,
('Texas', 2010): 25145561,
('New York', 2000): 18976457,
('New York', 2010): 19378102}
pd.Series(data)
Out[13]: California  2000    33871648
                     2010    37253956
         Texas       2000    20851820
                     2010    25145561
         New York    2000    18976457
                     2010    19378102
         dtype: int64
Nevertheless, it is sometimes useful to explicitly create a MultiIndex ; we'll see a couple of these methods here.
然⽽,有时候显式地创建 MultiIndex 对象也是很有⽤的;我们下⾯会看到⼀些这些⽅法。
Explicit MultiIndex constructors
显式 MultiIndex 构造器
For more flexibility in how the index is constructed, you can instead use the class method constructors available in the
pd.MultiIndex . For example, as we did before, you can construct the MultiIndex from a simple list of arrays
giving the index values within each level:
当你需要更灵活地构建多重索引时,你可以使⽤ pd.MultiIndex 的构造器。例如,你可以使⽤多重列表来构造⼀个和前⾯⼀样的
MultiIndex 对象:
In [14]: pd.MultiIndex.from_arrays([['a', 'a', 'b', 'b'], [1, 2, 1, 2]])
Out[14]: MultiIndex([('a', 1),
('a', 2),
('b', 1),
('b', 2)],
)
You can construct it from a list of tuples giving the multiple index values of each point:
你也可以使⽤⼀个元组的列表来构建⼀个多重索引:
In [15]: pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1), ('b', 2)])
Out[15]: MultiIndex([('a', 1),
('a', 2),
('b', 1),
('b', 2)],
)
You can even construct it from a Cartesian product of single indices:
你还可以⽤两个单⼀索引的笛卡尔乘积来构造:
In [16]: pd.MultiIndex.from_product([['a', 'b'], [1, 2]])
Out[16]: MultiIndex([('a', 1),
('a', 2),
('b', 1),
('b', 2)],
)
Similarly, you can construct the MultiIndex directly using its internal encoding by passing levels (a list of lists
containing available index values for each level) and codes (a list of lists that reference these values by position
within each level; older versions of Pandas called this argument labels):
同样,你可以⽤ MultiIndex 构造器来构造多重索引,你需要传递 levels (多重列表包括每个层次的索引值)和 labels (多重列表
包括数据点的标签值)参数:
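Translator's note: in newer versions of Pandas, the labels argument of the MultiIndex constructor has been replaced by
codes; the code below has been updated accordingly.
译者注:Pandas的 MultiIndex 构造器中的 labels 参数已被 codes 参数取代,下面代码进行了相应修改。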
In [17]: pd.MultiIndex(levels=[['a', 'b'], [1, 2]],
codes=[[0, 0, 1, 1], [0, 1, 0, 1]])
Out[17]: MultiIndex([('a', 1),
('a', 2),
('b', 1),
('b', 2)],
)
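A quick sketch confirming that the four constructors above describe the same index. It also shows the internal encoding: the unique values per level plus integer positions into them. `.tolist()` is used only to print the codes as plain Python lists.

```python
import pandas as pd

from_arrays = pd.MultiIndex.from_arrays([['a', 'a', 'b', 'b'], [1, 2, 1, 2]])
from_tuples = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1), ('b', 2)])
from_product = pd.MultiIndex.from_product([['a', 'b'], [1, 2]])
from_codes = pd.MultiIndex(levels=[['a', 'b'], [1, 2]],
                           codes=[[0, 0, 1, 1], [0, 1, 0, 1]])

# All four spellings describe the same index.
same = all(idx.equals(from_arrays) for idx in (from_tuples, from_product, from_codes))

# The internal encoding: unique values per level, plus positions into them.
print(from_arrays.levels)
print([c.tolist() for c in from_arrays.codes])   # [[0, 0, 1, 1], [0, 1, 0, 1]]
print(same)                                      # True
```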
Any of these objects can be passed as the index argument when creating a Series or DataFrame , or be passed
to the reindex method of an existing Series or DataFrame .
上⾯创建的这些对象都能作为 index 参数传递给 Series 或 DataFrame 构造器使⽤,或者作为 reindex ⽅法的参数提供给 Series
或 DataFrame 对象进⾏重新索引。
MultiIndex level names
MultiIndex 层次名称
Sometimes it is convenient to name the levels of the MultiIndex . This can be accomplished by passing the names
argument to any of the above MultiIndex constructors, or by setting the names attribute of the index after the fact:
为了⽅便有时需要给 MultiIndex 的不同层次进⾏命名。这可以通过在上⾯的 MultiIndex 构造⽅法中传递 names 参数,或者创建了
之后通过设置 names 属性来实现:
In [18]: pop.index.names = ['state', 'year']
pop
Out[18]: state       year
         California  2000    33871648
                     2010    37253956
         New York    2000    18976457
                     2010    19378102
         Texas       2000    20851820
                     2010    25145561
         dtype: int64
With more involved datasets, this can be a useful way to keep track of the meaning of various index values.
在复杂的数据集中,这种命名⽅式让不同的索引值保持它们原本的意义。
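A minimal sketch of the other route mentioned above: passing `names` to a constructor (`from_product` here). The names can then be used wherever a level number is accepted, for example in `xs()`.

```python
import pandas as pd

# Name the levels when the index is built...
index = pd.MultiIndex.from_product([['California', 'New York', 'Texas'], [2000, 2010]],
                                   names=['state', 'year'])
pop = pd.Series([33871648, 37253956, 18976457, 19378102,
                 20851820, 25145561], index=index)

# ...and refer to a level by name instead of by position.
by_year = pop.xs(2010, level='year')

print(list(pop.index.names))   # ['state', 'year']
print(by_year)
```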
MultiIndex for columns
列的 MultiIndex
In a DataFrame , the rows and columns are completely symmetric, and just as the rows can have multiple levels of
indices, the columns can have multiple levels as well. Consider the following, which is a mock-up of some (somewhat
realistic) medical data:
在⼀个 DataFrame 中,⾏和列是完全对称的,就像前⾯看到的⾏可以有多层次的索引,列也可以有多层次的索引。看下⾯的例⼦,⽤来
模拟真实的医疗数据:
In [19]: # hierarchical indices and columns
         index = pd.MultiIndex.from_product([[2013, 2014], [1, 2]],
                                            names=['year', 'visit'])
         columns = pd.MultiIndex.from_product([['Bob', 'Guido', 'Sue'], ['HR', 'Temp']],
                                              names=['subject', 'type'])

         # mock some data
         data = np.round(np.random.randn(4, 6), 1)
         data[:, ::2] *= 10
         data += 37

         # create the DataFrame
         health_data = pd.DataFrame(data, index=index, columns=columns)
         health_data
Out[19]: subject      Bob       Guido         Sue
         type          HR  Temp    HR  Temp    HR  Temp
         year visit
         2013 1      35.0  35.9  21.0  37.5  37.0  38.2
              2      28.0  37.3  43.0  38.3  35.0  36.2
         2014 1      51.0  36.3  33.0  39.0  29.0  35.8
              2      43.0  35.7  19.0  36.2  43.0  36.1
Here we see where the multi-indexing for both rows and columns can come in very handy. This is fundamentally
four-dimensional data, where the dimensions are the subject, the measurement type, the year, and the visit number. With this
in place we can, for example, index the top-level column by the person's name and get a full DataFrame containing
just that person's information:
我们看到多重索引对于⾏和列来说都是⾮常⽅便的。上⾯的数据集实际上是⼀个四维的数据,四个维度分别是受试者、测试类型、年份和
测试编号。创建了这个 DataFrame 之后,我们可以使⽤受试者的姓名来很⽅便的获取到此⼈的所有测试数据:
In [20]: health_data['Guido']
Out[20]: type          HR  Temp
         year visit
         2013 1      21.0  37.5
              2      43.0  38.3
         2014 1      33.0  39.0
              2      19.0  36.2
For complicated records containing multiple labeled measurements across multiple times for many subjects (people,
countries, cities, etc.) use of hierarchical rows and columns can be extremely convenient!
对于这种包含着多重标签的多种维度(⼈、国家、城市等)数据。使⽤这种层次化的⾏和列的结构会⾮常⽅便。
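Besides selecting a top-level column by name, you can pull out one measurement type across all subjects with `xs()` on the column axis. This sketch uses the values printed above instead of random numbers, so the result is reproducible.

```python
import numpy as np
import pandas as pd

index = pd.MultiIndex.from_product([[2013, 2014], [1, 2]], names=['year', 'visit'])
columns = pd.MultiIndex.from_product([['Bob', 'Guido', 'Sue'], ['HR', 'Temp']],
                                     names=['subject', 'type'])
values = np.array([[35.0, 35.9, 21.0, 37.5, 37.0, 38.2],
                   [28.0, 37.3, 43.0, 38.3, 35.0, 36.2],
                   [51.0, 36.3, 33.0, 39.0, 29.0, 35.8],
                   [43.0, 35.7, 19.0, 36.2, 43.0, 36.1]])
health_data = pd.DataFrame(values, index=index, columns=columns)

# Every subject's heart rate: select 'HR' on the 'type' level of the columns.
heart_rate = health_data.xs('HR', axis=1, level='type')
print(heart_rate)
```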
Indexing and Slicing a MultiIndex
在 MultiIndex 上检索和切⽚
Indexing and slicing on a MultiIndex is designed to be intuitive, and it helps if you think about the indices as added
dimensions.
在 MultiIndex 上进⾏检索和切⽚设计的⾮常直观,你可以将其想象为在新增的维度上进⾏检索能帮助你理解。
We'll first look at indexing multiply indexed Series objects, and then multiply indexed DataFrame objects.
我们先来看⼀下多重索引 Series 的⽅法,然后再看多重索引的 DataFrame 。
Multiply indexed Series
多重索引 Series
Consider the multiply indexed Series of state populations we saw earlier:
回头再看前⾯的那个⼈⼝的多重序列 Series :
In [21]: pop
Out[21]: state       year
         California  2000    33871648
                     2010    37253956
         New York    2000    18976457
                     2010    19378102
         Texas       2000    20851820
                     2010    25145561
         dtype: int64
We can access single elements by indexing with multiple terms:
我们可以使⽤多重索引值获取单个元素:
In [22]: pop['California', 2000]
Out[22]: 33871648
The MultiIndex also supports partial indexing, or indexing just one of the levels in the index. The result is another
Series , with the lower-level indices maintained:
MultiIndex 同样支持部分检索,即仅在索引中检索其中的一个层次。得到的结果是另一个 Series ,但是具有更少的层次结构:
In [23]: pop['California']
Out[23]: year
         2000    33871648
         2010    37253956
         dtype: int64
Partial slicing is available as well, as long as the MultiIndex is sorted (see discussion in Sorted and Unsorted
Indices):
部分切⽚同样也是⽀持的,只要 MultiIndex 是排序的(参⻅有序和⽆序的索引):
In [24]: pop.loc['California':'New York']
Out[24]: state       year
         California  2000    33871648
                     2010    37253956
         New York    2000    18976457
                     2010    19378102
         dtype: int64
With sorted indices, partial indexing can be performed on lower levels by passing an empty slice in the first index:
在有序索引的情况下,部分检索也可以⽤到低层次的索引上,只需要在第⼀个索引位置传递⼀个空的切⽚即可:
In [25]: pop[:, 2000]
Out[25]: state
         California    33871648
         New York      18976457
         Texas         20851820
         dtype: int64
Other types of indexing and selection (discussed in Data Indexing and Selection) work as well; for example, selection
based on Boolean masks:
其他类型的索引和选择(参⻅数据索引和选择)也是允许的;例如,使⽤布尔遮盖进⾏选择:
In [26]: pop[pop > 22000000]
Out[26]: state
California
year
2000
2010
2010
33871648
37253956
25145561
Texas
dtype: int64
Selection based on fancy indexing also works:
使⽤⾼级索引进⾏选择:
In [27]: pop[['California', 'Texas']]
Out[27]: state       year
         California  2000    33871648
                     2010    37253956
         Texas       2000    20851820
                     2010    25145561
         dtype: int64
Multiply indexed DataFrames
多重索引 DataFrame
A multiply indexed DataFrame behaves in a similar manner. Consider our toy medical DataFrame from before:
对 DataFrame 进⾏多重索引也是同样的。再看前⾯我们的医疗 DataFrame 数据:
In [28]: health_data
Out[28]: subject      Bob       Guido         Sue
         type          HR  Temp    HR  Temp    HR  Temp
         year visit
         2013 1      35.0  35.9  21.0  37.5  37.0  38.2
              2      28.0  37.3  43.0  38.3  35.0  36.2
         2014 1      51.0  36.3  33.0  39.0  29.0  35.8
              2      43.0  35.7  19.0  36.2  43.0  36.1
Remember that columns are primary in a DataFrame , and the syntax used for multiply indexed Series applies to
the columns. For example, we can recover Guido's heart rate data with a simple operation:
请注意 DataFrame 中主要的索引是列,你可以将上⾯的多重索引 Series 的⽅法应⽤到 DataFrame 的列上。例如,通过⼀个简单的操
作就能获得Guido的⼼率数据:
In [29]: health_data['Guido', 'HR']
Out[29]: year  visit
         2013  1        21.0
               2        43.0
         2014  1        33.0
               2        19.0
         Name: (Guido, HR), dtype: float64
Also, as with the single-index case, we can use the loc and iloc indexers introduced in Data Indexing and Selection
(the ix indexer mentioned there has since been removed from Pandas). For example:
同样,就像单一索引的情况那样,我们可以使用在(数据索引和选择)中介绍的 loc 和 iloc 索引符(其中提到的 ix 索引符已从Pandas中
移除)。例如:
In [30]: health_data.iloc[:2, :2]
Out[30]: subject      Bob
         type          HR  Temp
         year visit
         2013 1      35.0  35.9
              2      28.0  37.3
These indexers provide an array-like view of the underlying two-dimensional data, but each individual index in loc or
iloc can be passed a tuple of multiple indices. For example:
这些索引符提供了⼀个底层⼆维数据的数组视图,并且 loc 或 iloc 中每个独⽴的索引都可以传递⼀个多重索引的元组。例如:
In [31]: health_data.loc[:, ('Bob', 'HR')]
Out[31]: year  visit
         2013  1        35.0
               2        28.0
         2014  1        51.0
               2        43.0
         Name: (Bob, HR), dtype: float64
Working with slices within these index tuples is not especially convenient; trying to create a slice within a tuple will lead to
a syntax error:
使⽤这种索引元组并不是特别的⽅便;例如试图在元组中使⽤切⽚会产⽣⼀个语法错误:
In [32]: health_data.loc[(:, 1), (:, 'HR')]
File "<ipython-input-32-fb34fa30ac09>", line 1
health_data.loc[(:, 1), (:, 'HR')]
^
SyntaxError: invalid syntax
You could get around this by building the desired slice explicitly using Python's built-in slice() function, but a better
way in this context is to use an IndexSlice object, which Pandas provides for precisely this situation. For example:
解决上述问题的⽅法可以是显式调⽤Python內建的 slice() 函数,还有⼀个更好的⽅式是使⽤ IndexSlice 对象,该对象是Pandas专
⻔为这种情况准备的。例如:
In [33]: idx = pd.IndexSlice
health_data.loc[idx[:, 1], idx[:, 'HR']]
Out[33]: subject      Bob  Guido   Sue
         type          HR     HR    HR
         year visit
         2013 1      35.0   21.0  37.0
         2014 1      51.0   33.0  29.0
There are so many ways to interact with data in multiply indexed Series and DataFrame objects, and as with many tools in
this book the best way to become familiar with them is to try them out!
访问多重索引的 Series 和 DataFrame 对象中的数据有很多种⽅法,除了阅读本书中介绍的这些⼯具外,熟悉它们的最好⽅式就是在实
践中使⽤它们。
Rearranging Multi-Indices
重新排列多重索引
One of the keys to working with multiply indexed data is knowing how to effectively transform the data. There are a
number of operations that will preserve all the information in the dataset, but rearrange it for the purposes of various
computations. We saw a brief example of this in the stack() and unstack() methods, but there are many more
ways to finely control the rearrangement of data between hierarchical indices and columns, and we'll explore them here.
使⽤多重索引数据的⼀个关键技能是掌握如何有效地转换数据形式。Pandas提供了⼀些操作能保留数据集的信息,并根据不同⽬的的计算
需要对数据进⾏重新排列。前⾯我们已经看到了 stack() 和 unstack() ⽅法的简单介绍,实际上还有更多操作可以⽤来精细控制数据
集的层次化的⾏和列索引,下⾯我们来介绍它们。
Sorted and unsorted indices
有序和⽆序的索引
Earlier, we briefly mentioned a caveat, but we should emphasize it more here. Many of the MultiIndex slicing
operations will fail if the index is not sorted. Let's take a look at this here.
前⾯我们稍微提到了有序和⽆序索引的概念,这⾥我们要强调⼀下。如果索引是⽆序的话,很多 MultiIndex 的切⽚操作都会失败。
We'll start by creating some simple multiply indexed data where the indices are not lexicographically sorted:
我们来创建⼀些简单的多重索引数据,它们的索引不是具有⾃然顺序的:
In [34]: index = pd.MultiIndex.from_product([['a', 'c', 'b'], [1, 2]])
data = pd.Series(np.random.rand(6), index=index)
data.index.names = ['char', 'int']
data
Out[34]: char  int
         a     1      0.923424
               2      0.785119
         c     1      0.878949
               2      0.473416
         b     1      0.505453
               2      0.064504
         dtype: float64
If we try to take a partial slice of this index, it will result in an error:
如果我们试图对这个 Series 对象进行切片,结果会发生错误:
In [35]: try:
             data['a':'b']
         except KeyError as e:
             print(type(e))
             print(e)
<class 'pandas.errors.UnsortedIndexError'>
'Key length (1) was greater than MultiIndex lexsort depth (0)'
Although it is not entirely clear from the error message, this is the result of the MultiIndex not being sorted. For various
reasons, partial slices and other similar operations require the levels in the MultiIndex to be in sorted (i.e.,
lexicographical) order. Pandas provides a number of convenience routines to perform this type of sorting; the main one is
the sort_index() method of Series and DataFrame (older versions also offered a DataFrame.sortlevel() method, which has
since been removed). We'll use sort_index() here:
虽然错误的信息并不是那么清晰易懂,实际上这是MultiIndex没有排序的结果。许多因素决定了,当对 MultiIndex 进行部分的切片和其
他相似的操作时,都需要索引是有序(即按字典序排列)的。Pandas提供了方法来对索引进行排序;主要是 Series 和 DataFrame 对象的
sort_index() 方法(早期版本的 DataFrame 还提供了 sortlevel() 方法,现已被移除)。我们在这里使用 sort_index() 方法:
In [36]: data = data.sort_index()
data
Out[36]: char  int
         a     1      0.923424
               2      0.785119
         b     1      0.505453
               2      0.064504
         c     1      0.878949
               2      0.473416
         dtype: float64
With the index sorted in this way, partial slicing will work as expected:
当索引排好序后,索引的切⽚就可以正常⼯作了:
In [37]: data['a':'b']
Out[37]: char  int
         a     1      0.923424
               2      0.785119
         b     1      0.505453
               2      0.064504
         dtype: float64
Stacking and unstacking indices
索引的堆叠和拆分
As we saw briefly before, it is possible to convert a dataset from a stacked multi-index to a simple two-dimensional
representation, optionally specifying the level to use:
我们前⾯已经看到,我们可以将⼀个堆叠的多重索引的数据集拆分成⼀个简单的⼆维形式,还可以指定使⽤哪个层次进⾏拆分:
In [38]: pop.unstack(level=0)
Out[38]: state  California  New York     Texas
         year
         2000     33871648  18976457  20851820
         2010     37253956  19378102  25145561
In [39]: pop.unstack(level=1)
Out[39]: year            2000      2010
         state
         California  33871648  37253956
         New York    18976457  19378102
         Texas       20851820  25145561
The opposite of unstack() is stack() , which here can be used to recover the original series:
unstack() 的逆操作是 stack() ,我们可以使用它来重新堆叠数据集:
In [40]: pop.unstack().stack()
Out[40]: state       year
         California  2000    33871648
                     2010    37253956
         New York    2000    18976457
                     2010    19378102
         Texas       2000    20851820
                     2010    25145561
         dtype: int64
Index setting and resetting
设置及重新设置索引
Another way to rearrange hierarchical data is to turn the index labels into columns; this can be accomplished with the
reset_index method. Calling this on the population Series will result in a DataFrame with a state and year
column holding the information that was formerly in the index. For clarity, we can optionally specify the name of the data
for the column representation:
还有一种重新排列层次化数据的方式是将行索引标签转为数据列;这可以使用 reset_index 方法来实现。在人口 Series 上调用这个方法
会得到一个 DataFrame ,其中的 state 和 year 列保存了原来索引中的信息。为了清晰起见,我们可以为数据列设置名称:
In [41]: pop_flat = pop.reset_index(name='population')
pop_flat
Out[41]:         state  year  population
         0  California  2000    33871648
         1  California  2010    37253956
         2    New York  2000    18976457
         3    New York  2010    19378102
         4       Texas  2000    20851820
         5       Texas  2010    25145561
Often when working with data in the real world, the raw input data looks like this and it's useful to build a MultiIndex
from the column values. This can be done with the set_index method of the DataFrame , which returns a multiply
indexed DataFrame :
通常当我们处理真实世界的数据的时候,我们看到的就会是如上的数据集的形式,因此从列当中构建⼀个 MultiIndex 会很有⽤。这可
以通过在 DataFrame 上使⽤ set_index ⽅法来实现,这样会返回⼀个多重索引的 DataFrame :
In [42]: pop_flat.set_index(['state', 'year'])
Out[42]:                  population
         state      year
         California 2000    33871648
                    2010    37253956
         New York   2000    18976457
                    2010    19378102
         Texas      2000    20851820
                    2010    25145561
In practice, I find this type of reindexing to be one of the more useful patterns when encountering real-world datasets.
在实践中,作者发现当处理真实世界数据集时,这种重新索引的⽅法会经常被⽤到。
Data Aggregations on Multi-Indices
多重索引的数据聚合
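We've previously seen that Pandas has built-in data aggregation methods, such as mean() , sum() , and max() .
For hierarchically indexed data, we can also control which subset of the data the aggregate is computed on by naming an
index level. Older versions of Pandas accepted a level parameter directly in these methods (e.g., mean(level='year'));
current versions express the same thing with groupby(level=...).
前面我们已经了解到Pandas有内建的数据聚合方法,例如 mean() 、 sum() 和 max() 。对于层次化索引的数据而言,我们还可以通过指定
索引层次来控制数据沿着哪个层次的索引来进行计算。早期版本的Pandas可以直接向这些方法传递 level 参数(例如 mean(level='year'));
当前版本则使用 groupby(level=...) 来实现同样的功能。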
For example, let's return to our health data:
例如,再看我们的那个健康数据集:
In [43]: health_data
Out[43]: subject      Bob       Guido         Sue
         type          HR  Temp    HR  Temp    HR  Temp
         year visit
         2013 1      35.0  35.9  21.0  37.5  37.0  38.2
              2      28.0  37.3  43.0  38.3  35.0  36.2
         2014 1      51.0  36.3  33.0  39.0  29.0  35.8
              2      43.0  35.7  19.0  36.2  43.0  36.1
Perhaps we'd like to average-out the measurements in the two visits each year. We can do this by naming the index level
we'd like to explore, in this case the year:
可能我们希望能将每年测量值进⾏平均。我们可以⽤level参数指定我们需要进⾏聚合的标签,这⾥是年份:
In [44]: data_mean = health_data.mean(level='year')
         data_mean
Out[44]:
subject   Bob       Guido        Sue
type       HR  Temp    HR  Temp    HR   Temp
year
2013     31.5  36.6  32.0  37.9  36.0  37.20
2014     47.0  36.0  26.0  37.6  36.0  35.95
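The level= argument of reductions such as mean() was deprecated in pandas 1.3 and removed in pandas 2.0, so the call in In [44] only runs on older releases. A minimal sketch of the GroupBy form that current versions accept, assuming health_data is rebuilt by hand from the values printed in Out[43]:

```python
import pandas as pd

# Rebuild the health data shown above (values copied from Out[43])
index = pd.MultiIndex.from_product([[2013, 2014], [1, 2]],
                                   names=['year', 'visit'])
columns = pd.MultiIndex.from_product([['Bob', 'Guido', 'Sue'], ['HR', 'Temp']],
                                     names=['subject', 'type'])
health_data = pd.DataFrame([[35.0, 35.9, 21.0, 37.5, 37.0, 38.2],
                            [28.0, 37.3, 43.0, 38.3, 35.0, 36.2],
                            [51.0, 36.3, 33.0, 39.0, 29.0, 35.8],
                            [43.0, 35.7, 19.0, 36.2, 43.0, 36.1]],
                           index=index, columns=columns)

# Average the two visits of each year by grouping rows on the 'year' level
data_mean = health_data.groupby(level='year').mean()
print(data_mean)
```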
By further making use of the axis keyword, we can take the mean among levels on the columns as well:
通过额外指定 axis 关键字,我们可以在列上沿着某个层次 level 进⾏聚合:
In [45]: data_mean.mean(axis=1, level='type')
Out[45]:
type         HR       Temp
year
2013  33.166667  37.233333
2014  36.333333  36.516667
Thus in two lines, we've been able to find the average heart rate and temperature measured among all subjects in all
visits each year. This syntax is actually a shortcut to the GroupBy functionality, which we will discuss in Aggregation
and Grouping. While this is a toy example, many real-world datasets have similar hierarchical structure.
虽然只有两行代码,我们已经能够计算得到所有受试者每年多次测试取样的平均心率和体温。这个语法实际上是 GroupBy 功能的一种
简略写法,我们会在聚合和分组一节中详细介绍。虽然这只是一个模拟的数据集,但是很多真实世界的数据集也有相似的层次化结构。
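The column-wise call in In [45] relies on the same level= argument, which pandas 2.0 no longer accepts. One option on current versions, sketched here as an illustration rather than the book's method, is to transpose, group by the former column level, and transpose back. The frame below is rebuilt by hand from Out[44]:

```python
import pandas as pd

# data_mean from Out[44]: one row per year, (subject, type) columns
columns = pd.MultiIndex.from_product([['Bob', 'Guido', 'Sue'], ['HR', 'Temp']],
                                     names=['subject', 'type'])
data_mean = pd.DataFrame([[31.5, 36.6, 32.0, 37.9, 36.0, 37.20],
                          [47.0, 36.0, 26.0, 37.6, 36.0, 35.95]],
                         index=pd.Index([2013, 2014], name='year'),
                         columns=columns)

# Transpose so the column levels become row levels, average over subjects
# within each measurement type, then transpose back to years-as-rows
by_type = data_mean.T.groupby(level='type').mean().T
print(by_type)
```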
Aside: Panel Data
额外知识:Panel数据
Pandas has a few other fundamental data structures that we have not yet discussed, namely the pd.Panel and
pd.Panel4D objects. These can be thought of, respectively, as three-dimensional and four-dimensional generalizations
of the (one-dimensional) Series and (two-dimensional) DataFrame structures. Once you are familiar with indexing
and manipulation of data in a Series and DataFrame , Panel and Panel4D are relatively straightforward to use.
In particular, the ix , loc , and iloc indexers discussed in Data Indexing and Selection extend readily to these
higher-dimensional structures.
Pandas 还有一些其他的基础数据结构我们没有介绍到,即 pd.Panel 和 pd.Panel4D 对象。它们可以分别被看作是一维的
Series 和二维的 DataFrame 在三维和四维上的推广。一旦你熟悉了 Series 和 DataFrame 的使用方法, Panel 和
Panel4D 的使用相对来说也是很直观的。特别的,我们在数据索引和选择中介绍过的 ix 、 loc 和 iloc 索引符在高维结构中也是直接
可用的。
We won't cover these panel structures further in this text, as I've found in the majority of cases that multi-indexing is a
more useful and conceptually simpler representation for higher-dimensional data. Additionally, panel data is
fundamentally a dense data representation, while multi-indexing is fundamentally a sparse data representation. As the
number of dimensions increases, the dense representation can become very inefficient for the majority of real-world
datasets. For the occasional specialized application, however, these structures can be useful. If you'd like to read more
about the Panel and Panel4D structures, see the references listed in Further Resources.
我们不会在本书中继续介绍Panel结构,因为作者认为在⼤多数情况下多重索引会更加有⽤,在表现⾼维数据时概念也会显得更加简单。⽽
且更加重要的是,⾯板数据从基本上来说是密集数据,⽽多重索引从基本上来说是稀疏数据。随着维度数量的增加,使⽤密集数据⽅式表
⽰真实世界的数据是⾮常的低效的。但是对于⼀些特殊的应⽤来说,这些结构是很有⽤的。如果你希望获取更多有关 Panel 和 Panel4D
结构的内容,请查阅更多资源。
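pd.Panel was removed in pandas 0.25, so the panel structures above are not available in current releases. A minimal sketch of the multi-index alternative the text recommends, using hypothetical city/year labels and values: three-dimensional data (city, year, measurement) lives in a DataFrame with a two-level row index, where only observed combinations occupy rows, while unstack produces the dense, panel-like layout with a slot for every combination.

```python
import pandas as pd

# Hypothetical 3-D data: (city, year) rows times measurement columns.
# Only the observed combinations are stored; ('B', 2000) is simply absent.
index = pd.MultiIndex.from_tuples([('A', 2000), ('A', 2010), ('B', 2010)],
                                  names=['city', 'year'])
df = pd.DataFrame({'pop': [10, 12, 7], 'area': [1.5, 1.5, 0.8]}, index=index)

# The dense layout reserves a cell for every city/year pair,
# so the missing observation turns into NaN placeholders
dense = df.unstack('year')

print(df.size, dense.size, int(dense.isna().sum().sum()))  # 6 8 2
```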
Combining Datasets: Concat and Append
组合数据集:Concat 和 Append
Some of the most interesting studies of data come from combining different data sources. These operations can involve
anything from very straightforward concatenation of two different datasets, to more complicated database-style joins and
merges that correctly handle any overlaps between the datasets. Series and DataFrame s are built with this type of
operation in mind, and Pandas includes functions and methods that make this sort of data wrangling fast and
straightforward.
很多对数据进⾏的有趣的研究都来源⾃不同数据源的组合。这些组合操作包括很直接的连接两个不同的数据集,到更复杂的数据库⻛格的
联表和组合可以正确的处理数据集之间的重复部分。 Series 和 DataFrame 內建了对这些操作的⽀持,Pandas提供的函数和⽅法能够
让这种数据操作⾼效⽽直接。
Here we'll take a look at simple concatenation of Series and DataFrame s with the pd.concat function; later we'll
dive into more sophisticated in-memory merges and joins implemented in Pandas.
本节中我们会简单介绍使⽤ pd.concat 函数对 Series 和 DataFrame 进⾏连接;然后我们深⼊讨论Pandas中复杂的内存级别的合并
及联表操作。
We begin with the standard imports:
⾸先还是标准载⼊:
In [1]: import pandas as pd
import numpy as np
For convenience, we'll define this function which creates a DataFrame of a particular form that will be useful below:
为了⽅便起⻅,我们定义下⾯这个函数⽤来创建⼀个 DataFrame ,本节后续的 DataFrame 都来源⾃该函数:
In [2]: def make_df(cols, ind):
            """Quickly make a DataFrame"""
            data = {c: [str(c) + str(i) for i in ind]
                    for c in cols}
            return pd.DataFrame(data, ind)

        # example DataFrame
        make_df('ABC', range(3))
Out[2]:
    A   B   C
0  A0  B0  C0
1  A1  B1  C1
2  A2  B2  C2
In addition, we'll create a quick class that allows us to display multiple DataFrame s side by side. The code makes use
of the special _repr_html_ method, which IPython uses to implement its rich object display:
除此之外,我们还要创建⼀个类,⽤来将多个 DataFrame 紧靠着进⾏展⽰。下⾯的代码实现了特殊的 _repr_html_ ⽅法,IPython使
⽤这个⽅法来展⽰对象的HTML格式:
In [3]: class display(object):
            """Display HTML representation of multiple objects"""
            template = """<div style="float: left; padding: 10px;">
            <p style='font-family:"Courier New", Courier, monospace'>{0}</p>{1}
            </div>"""
            def __init__(self, *args):
                self.args = args

            def _repr_html_(self):
                return '\n'.join(self.template.format(a, eval(a)._repr_html_())
                                 for a in self.args)

            def __repr__(self):
                return '\n\n'.join(a + '\n' + repr(eval(a))
                                   for a in self.args)
The use of this will become clearer as we continue our discussion in the following section.
这个类的使⽤⽅式会在后续进⼀步介绍。
Recall: Concatenation of NumPy Arrays
复习:NumPy数组的连接
Concatenation of Series and DataFrame objects is very similar to concatenation of NumPy arrays, which can be
done via the np.concatenate function as discussed in The Basics of NumPy Arrays. Recall that with it, you can
combine the contents of two or more arrays into a single array:
Series 和 DataFrame 对象的连接与NumPy数组的连接非常相似,NumPy数组的连接可以通过NumPy数组基础一节中介绍过的
np.concatenate 函数来实现。回忆一下,你可以将两个或多个数组连接成一个数组:
In [4]: x = [1, 2, 3]
y = [4, 5, 6]
z = [7, 8, 9]
np.concatenate([x, y, z])
Out[4]: array([1, 2, 3, 4, 5, 6, 7, 8, 9])
The first argument is a list or tuple of arrays to concatenate. Additionally, it takes an axis keyword that allows you to
specify the axis along which the result will be concatenated:
第⼀个参数是需要进⾏连接的数组的元组或列表。函数还可以提供⼀个 axis 关键字参数来指定沿着哪个维度⽅向对数组进⾏连接:
In [5]: x = [[1, 2],
             [3, 4]]
        np.concatenate([x, x], axis=1)
Out[5]: array([[1, 2, 1, 2],
               [3, 4, 3, 4]])
Simple Concatenation with pd.concat
使⽤ pd.concat 进⾏简单连接
Pandas has a function, pd.concat() , which has a similar syntax to np.concatenate but contains a number of
options that we'll discuss momentarily:
Pandas有相应的函数 pd.concat() ,与 np.concatenate 有着相似的语法,但是有一些参数我们需要深入讨论:
# Function signature in Pandas v0.24.2
pd.concat(
    objs,
    axis=0,
    join='outer',
    join_axes=None,
    ignore_index=False,
    keys=None,
    levels=None,
    names=None,
    verify_integrity=False,
    sort=None,
    copy=True,
)
pd.concat() can be used for a simple concatenation of Series or DataFrame objects, just as
np.concatenate() can be used for simple concatenations of arrays:
pd.concat() 可以用来对 Series 或 DataFrame 对象进行简单的连接,就像可以用 np.concatenate() 来对数组进行简单连接一样:
In [6]: ser1 = pd.Series(['A', 'B', 'C'], index=[1, 2, 3])
        ser2 = pd.Series(['D', 'E', 'F'], index=[4, 5, 6])
        pd.concat([ser1, ser2])
Out[6]: 1    A
        2    B
        3    C
        4    D
        5    E
        6    F
        dtype: object
It also works to concatenate higher-dimensional objects, such as DataFrame s:
pd.concat() 函数也可以应用到高维对象上,例如 DataFrame :
In [7]: df1 = make_df('AB', [1, 2])
        df2 = make_df('AB', [3, 4])
        display('df1', 'df2', 'pd.concat([df1, df2])')
Out[7]:
df1
    A   B
1  A1  B1
2  A2  B2

df2
    A   B
3  A3  B3
4  A4  B4

pd.concat([df1, df2])
    A   B
1  A1  B1
2  A2  B2
3  A3  B3
4  A4  B4
By default, the concatenation takes place row-wise within the DataFrame (i.e., axis=0 ). Like np.concatenate ,
pd.concat allows specification of an axis along which concatenation will take place. Consider the following example:
默认情况下,连接会按照 DataFrame 的⾏来进⾏(即 axis=0 )。就像 np.concatenate 那样, pd.concat 允许指定沿着哪个维
度⽅向进⾏连接,看下例:
In [8]: df3 = make_df('AB', [0, 1])
        df4 = make_df('CD', [0, 1])
        display('df3', 'df4', "pd.concat([df3, df4], axis='columns')")
Out[8]:
df3
    A   B
0  A0  B0
1  A1  B1

df4
    C   D
0  C0  D0
1  C1  D1

pd.concat([df3, df4], axis='columns')
    A   B   C   D
0  A0  B0  C0  D0
1  A1  B1  C1  D1
We could have equivalently specified axis=1 ; here we've used the more intuitive axis='columns' .
我们也可以等价地使用 axis=1 ;这里我们使用了更加直观的方式 axis='columns' 。
Translator's note: the original text used axis='col' ; newer versions of Pandas require 'columns' instead.
Duplicate indices
重复的⾏索引
One important difference between np.concatenate and pd.concat is that Pandas concatenation preserves
indices, even if the result will have duplicate indices! Consider this simple example:
np.concatenate 和 pd.concat 的一个重要区别是Pandas的连接会保留行索引,甚至在结果中包含重复索引的情况下。看下例:
In [9]: x = make_df('AB', [0, 1])
        y = make_df('AB', [2, 3])
        y.index = x.index  # make duplicate indices!
        display('x', 'y', 'pd.concat([x, y])')
Out[9]:
x
    A   B
0  A0  B0
1  A1  B1

y
    A   B
0  A2  B2
1  A3  B3

pd.concat([x, y])
    A   B
0  A0  B0
1  A1  B1
0  A2  B2
1  A3  B3
Notice the repeated indices in the result. While this is valid within DataFrame s, the outcome is often undesirable.
pd.concat() gives us a few ways to handle it.
注意看到结果中的重复索引。虽然这是 DataFrame 允许的,但是结果通常不是你希望的。 pd.concat() 提供了⼀些处理这个问题的⽅
法。
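Besides the options below, the duplicates can also be spotted after the fact by inspecting the resulting index with the standard Index.is_unique and Index.duplicated() members. A minimal sketch, rebuilding x and y by hand instead of with make_df:

```python
import pandas as pd

# Two small frames that share the row labels 0 and 1 (same shape as x, y above)
x = pd.DataFrame({'A': ['A0', 'A1'], 'B': ['B0', 'B1']}, index=[0, 1])
y = pd.DataFrame({'A': ['A2', 'A3'], 'B': ['B2', 'B3']}, index=[0, 1])

result = pd.concat([x, y])
print(result.index.is_unique)              # False
print(result.index.duplicated().tolist())  # [False, False, True, True]

# A duplicated label now selects more than one row
print(result.loc[0])
```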
Catching the repeats as an error
将重复的索引捕获为错误
If you'd like to simply verify that the indices in the result of pd.concat() do not overlap, you can specify the
verify_integrity flag. With this set to True, the concatenation will raise an exception if there are duplicate indices.
Here is an example, where for clarity we'll catch and print the error message:
如果你只是希望验证 pd.concat() 结果数据集中是否含有重复的索引,你可以传递 verify_integrity=True 参数。这时
连接结果的数据集中如果存在重复的行索引,将会抛出一个错误。下面这个例子,我们将捕获到这个错误并输出:
In [10]: try:
             pd.concat([x, y], verify_integrity=True)
         except ValueError as e:
             print("ValueError:", e)
ValueError: Indexes have overlapping values: Int64Index([0, 1], dtype='int64')
Ignoring the index
忽略⾏索引
Sometimes the index itself does not matter, and you would prefer it to simply be ignored. This option can be specified
using the ignore_index flag. With this set to True, the concatenation will create a new integer index for the resulting
object:
有些情况下,索引本⾝并不重要,那么可以选择忽略它们。给函数传递⼀个 ignore_index=True 的参数, pd.concat 函数会忽略连
接时的⾏索引,并在结果中重新创建⼀个整数的索引值:
In [11]: display('x', 'y', 'pd.concat([x, y], ignore_index=True)')
Out[11]:
x
    A   B
0  A0  B0
1  A1  B1

y
    A   B
0  A2  B2
1  A3  B3

pd.concat([x, y], ignore_index=True)
    A   B
0  A0  B0
1  A1  B1
2  A2  B2
3  A3  B3
Adding MultiIndex keys
增加多重索引标签
Another option is to use the keys option to specify a label for the data sources; the result will be a hierarchically
indexed object containing the data:
还有⼀种⽅法是使⽤ keys 参数来指定不同数据集的索引标签;这时 pd.concat 的结果会是包含着连接数据集的多重索引数据集:
In [12]: display('x', 'y', "pd.concat([x, y], keys=['x', 'y'])")
Out[12]:
x
    A   B
0  A0  B0
1  A1  B1

y
    A   B
0  A2  B2
1  A3  B3

pd.concat([x, y], keys=['x', 'y'])
      A   B
x 0  A0  B0
  1  A1  B1
y 0  A2  B2
  1  A3  B3
The result is a multiply indexed DataFrame , and we can use the tools discussed in Hierarchical Indexing to transform
this data into the representation we're interested in.
上例中的结果是⼀个多重索引的 DataFrame ,我们可以使⽤层次化的索引中介绍到的⽅法来转换或者展⽰连接结果的数据。
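A minimal sketch of working with the keyed result: pd.concat also accepts a names= argument that labels the new index levels (the level names 'source' and 'row' below are chosen for illustration), after which .loc on the outer label recovers each input and reset_index turns the source label into an ordinary column.

```python
import pandas as pd

x = pd.DataFrame({'A': ['A0', 'A1'], 'B': ['B0', 'B1']}, index=[0, 1])
y = pd.DataFrame({'A': ['A2', 'A3'], 'B': ['B2', 'B3']}, index=[0, 1])

# keys= labels each source; names= names the two resulting index levels
combined = pd.concat([x, y], keys=['x', 'y'], names=['source', 'row'])

# Selecting on the outer level gives back the piece that came from y
print(combined.loc['y'])

# Moving the 'source' level into a column flattens the hierarchy again
print(combined.reset_index(level='source'))
```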
Concatenation with joins
使⽤联表⽅式连接
In the simple examples we just looked at, we were mainly concatenating DataFrame s with shared column names. In
practice, data from different sources might have different sets of column names, and pd.concat offers several options
in this case. Consider the concatenation of the following two DataFrame s, which have some (but not all!) columns in
common:
在上⾯我们看到的简单例⼦中,我们连接的数据集都具有相同的列及标签。在实际情况中,从不同源得到的数据通常具有不同的列数或者
列标签, pd.concat 提供了⼏个相应的参数帮助我们完成上⾯的任务。下例中的两个数据集只有部分(⾮全部)列和标签相同:
Translator's note: newer versions of Pandas change the default value of the sort parameter; it will default to False going forward.
In [13]: df5 = make_df('ABC', [1, 2])
         df6 = make_df('BCD', [3, 4])
         display('df5', 'df6', 'pd.concat([df5, df6])')
Out[13]:
df5
    A   B   C
1  A1  B1  C1
2  A2  B2  C2

df6
    B   C   D
3  B3  C3  D3
4  B4  C4  D4

pd.concat([df5, df6])
     A   B   C    D
1   A1  B1  C1  NaN
2   A2  B2  C2  NaN
3  NaN  B3  C3   D3
4  NaN  B4  C4   D4
By default, the entries for which no data is available are filled with NA values. To change this, we can specify one of
several options for the join and join_axes parameters of the pd.concat function. By default, the join is a union of
the input columns ( join='outer' ), but we can change this to an intersection of the columns using join='inner' :
默认情况下,那些对应源数据集中不存在的元素值,将被填充为NA值。如果想改变默认⾏为,我们可以通过指定 join 和 join_axes 参
数来实现。 join 参数默认为 join='outer' ,就像我们上⾯看到的情况,结果是数据集的并集;如果将 join='inner' 传递给
pd.concat ,那么就会是数据源中相同的列保留在结果中,因此结果是数据集的交集:
In [14]: display('df5', 'df6',
                 "pd.concat([df5, df6], join='inner')")
Out[14]:
df5
    A   B   C
1  A1  B1  C1
2  A2  B2  C2

df6
    B   C   D
3  B3  C3  D3
4  B4  C4  D4

pd.concat([df5, df6], join='inner')
    B   C
1  B1  C1
2  B2  C2
3  B3  C3
4  B4  C4
Another option is to directly specify the index of the remaining columns using the join_axes argument, which takes a
list of index objects. Here we'll specify that the returned columns should be the same as those of the first input:
还可以通过另⼀个参数 join_axes 来指定结果中保留的列,该参数接受被保留索引标签的列表。下例中我们指定结果中的列和第⼀个进
⾏连接的数据集完全相同:
Translator's note: pandas 1.0 removed the join_axes keyword argument; the same result can be achieved with the reindex
method, which the code below uses to keep all of the columns of df5 .
In [15]: display('df5', 'df6',
                 "pd.concat([df5, df6]).reindex(df5.columns, axis=1)")
Out[15]:
df5
    A   B   C
1  A1  B1  C1
2  A2  B2  C2

df6
    B   C   D
3  B3  C3  D3
4  B4  C4  D4

pd.concat([df5, df6]).reindex(df5.columns, axis=1)
     A   B   C
1   A1  B1  C1
2   A2  B2  C2
3  NaN  B3  C3
4  NaN  B4  C4
The combination of options of the pd.concat function allows a wide range of possible behaviors when joining two
datasets; keep these in mind as you use these tools for your own data.
pd.concat 函数的参数很多,组合使用它们能解决组合多个数据集中的很多问题;请记住当你在自己的数据上操作时,你可以灵活地应
用它们,完成你的工作目标。
The append() method
append() 方法
Because direct array concatenation is so common, Series and DataFrame objects have an append method that
can accomplish the same thing in fewer keystrokes. For example, rather than calling pd.concat([df1, df2]) , you
can simply call df1.append(df2) :
因为数据集的连接操作是很普遍的, Series 和 DataFrame 对象都有一个 append 方法,它能完成和 pd.concat 一样的功能,并能
让你写代码时少敲几次键盘。例如你可以简单地调用 df1.append(df2) 而不是调用 pd.concat([df1, df2]) :
In [16]: display('df1', 'df2', 'df1.append(df2)')
Out[16]:
df1
    A   B
1  A1  B1
2  A2  B2

df2
    A   B
3  A3  B3
4  A4  B4

df1.append(df2)
    A   B
1  A1  B1
2  A2  B2
3  A3  B3
4  A4  B4
Keep in mind that unlike the append() and extend() methods of Python lists, the append() method in Pandas
does not modify the original object; instead it creates a new object with the combined data. It also is not a very efficient
method, because it involves creation of a new index and data buffer. Thus, if you plan to do multiple append
operations, it is generally better to build a list of DataFrame s and pass them all at once to the concat() function.
最后记住不像Python列表的 append() 和 extend() 方法,Pandas中的 append() 方法不会修改原始参与运算的数据集,它会为合并后
的结果创建⼀个新的对象。它也不是⼀个很⾼性能的⽅法,因为涉及到新索引和数据缓冲区的创建。因此如果你有需要连接多个数据集
时,应该避免多次使⽤ append ⽅法,⽽是将所有需要进⾏连接的数据集形成⼀个列表,并传递给 concat 函数来进⾏连接操作。
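DataFrame.append and Series.append were deprecated in pandas 1.4 and removed in pandas 2.0, so on current releases the list-then-concat pattern recommended above is the way to combine many pieces. A minimal sketch with hypothetical chunk frames built in a loop:

```python
import pandas as pd

# Hypothetical pieces arriving one at a time (for example, read in a loop)
pieces = []
for i in range(3):
    chunk = pd.DataFrame({'step': [i], 'value': [i * 10]})
    pieces.append(chunk)   # plain Python list append: no DataFrame copy yet

# One concatenation at the end builds the new index and data buffer once
combined = pd.concat(pieces, ignore_index=True)
print(combined)
```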
In the next section, we'll look at another more powerful approach to combining data from multiple sources, the database-style merges/joins implemented in pd.merge . For more information on concat() , append() , and related
functionality, see the "Merge, Join, and Concatenate" section of the Pandas documentation.
下⼀节中,我们会介绍另外⼀种更强⼤的从不同数据源组合数据的⽅法,即数据库⻛格的联表和合并 pd.merge 。需要查阅更多有关
concat() 、 append() 的知识,可以访问Pandas在线⽂档 - "合并、联表及连接"。
Combining Datasets: Merge and Join
组合数据集:Merge 和 Join
One essential feature offered by Pandas is its high-performance, in-memory join and merge operations. If you have ever
worked with databases, you should be familiar with this type of data interaction. The main interface for this is the
pd.merge function, and we'll see a few examples of how this can work in practice.
Pandas 提供的一个基本的特性就是它的高性能、在内存中进行的联表和组合操作。如果你使用过数据库,你应该已经很熟悉相关的数据操作
了。Pandas 在这方面提供的主要接口是 pd.merge 函数,本节中我们会看到一些具体实践的例子。
For convenience, we will start by redefining the display() functionality from the previous section:
为了⽅便起⻅,我们重新定义上⼀节中的 display() 类,⽤来展⽰多个数据集:
In [1]: import pandas as pd
        import numpy as np

        class display(object):
            """Display HTML representation of multiple objects"""
            template = """<div style="float: left; padding: 10px;">
            <p style='font-family:"Courier New", Courier, monospace'>{0}</p>{1}
            </div>"""
            def __init__(self, *args):
                self.args = args

            def _repr_html_(self):
                return '\n'.join(self.template.format(a, eval(a)._repr_html_())
                                 for a in self.args)

            def __repr__(self):
                return '\n\n'.join(a + '\n' + repr(eval(a))
                                   for a in self.args)
Relational Algebra
关系代数
The behavior implemented in pd.merge() is a subset of what is known as relational algebra, which is a formal set of
rules for manipulating relational data, and forms the conceptual foundation of operations available in most databases.
The strength of the relational algebra approach is that it proposes several primitive operations, which become the building
blocks of more complicated operations on any dataset. With this lexicon of fundamental operations implemented
efficiently in a database or other program, a wide range of fairly complicated composite operations can be performed.
pd.merge() 实现的是我们称为关系代数的一个子集,关系代数是一系列操作关系数据的规则的集合,它构成了大部分数据库的数学基
础。关系代数的力量表现在它仅提出了几个基本的运算,这些基本运算成为了更多复杂运算的组成模块,能够应用到任何的数据集上。只
要在数据库或者其他程序中实现了这些最基本的运算,那么绝大部分的复杂组合运算都可以在上面实现。
Pandas implements several of these fundamental building-blocks in the pd.merge() function and the related
join() method of DataFrame s. As we will see, these let you efficiently link data from different
sources.
Pandas 在 pd.merge() 函数中实现了一些上述所说的最基本的运算, DataFrame 的 join 方法也实现了这部分基本运
算,你将会看到,这能让你很高效地从不同数据源组合数据。
Categories of Joins
联表的分类
The pd.merge() function implements a number of types of joins: the one-to-one, many-to-one, and many-to-many
joins. All three types of joins are accessed via an identical call to the pd.merge() interface; the type of join performed
depends on the form of the input data. Here we will show simple examples of the three types of merges, and discuss
detailed options further below.
pd.merge() 函数实现了几种不同类型的联表:一对一、多对一和多对多。所有三种类型的联表都可以通过 pd.merge() 函数调用来实
现;具体使用了哪种类型的联表取决于输入数据的格式。下面我们会展示一些简单的例子来说明三种联表类型,然后我们还会详细的讨论
它们的选项。
One-to-one joins
⼀对⼀
Perhaps the simplest type of merge expression is the one-to-one join, which is in many ways very similar to the column-wise concatenation seen in Combining Datasets: Concat & Append. As a concrete example, consider the following two
DataFrames which contain information on several employees in a company:
也许最简单的联表操作类型就是⼀对⼀连接,在很多⽅⾯,这种联表都和我们在组合数据集:Concat 和 Append中看到的按列进⾏数据集
连接很相似。下⾯定义两个 DataFrame 含有公司的⼀些员⼯信息作为⼀个具体的例⼦来说明:
In [2]: df1 = pd.DataFrame({'employee': ['Bob', 'Jake', 'Lisa', 'Sue'],
                            'group': ['Accounting', 'Engineering', 'Engineering', 'HR']})
        df2 = pd.DataFrame({'employee': ['Lisa', 'Bob', 'Jake', 'Sue'],
                            'hire_date': [2004, 2008, 2012, 2014]})
        display('df1', 'df2')
Out[2]:
df1
  employee        group
0      Bob   Accounting
1     Jake  Engineering
2     Lisa  Engineering
3      Sue           HR

df2
  employee  hire_date
0     Lisa       2004
1      Bob       2008
2     Jake       2012
3      Sue       2014
To combine this information into a single DataFrame , we can use the pd.merge() function:
要将这两个数据集组合成⼀个 DataFrame ,我们可以使⽤ pd.merge 函数:
In [3]: df3 = pd.merge(df1, df2)
        df3
Out[3]:
  employee        group  hire_date
0      Bob   Accounting       2008
1     Jake  Engineering       2012
2     Lisa  Engineering       2004
3      Sue           HR       2014
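Because every employee appears exactly once in both inputs, this is a one-to-one join. pd.merge can be asked to verify that assumption through its validate parameter (available since pandas 0.21); if a key repeats, it raises pandas.errors.MergeError, which is a subclass of ValueError. A minimal sketch, where df2_dup is a hypothetical input with a duplicated key:

```python
import pandas as pd

df1 = pd.DataFrame({'employee': ['Bob', 'Jake', 'Lisa', 'Sue'],
                    'group': ['Accounting', 'Engineering', 'Engineering', 'HR']})
df2 = pd.DataFrame({'employee': ['Lisa', 'Bob', 'Jake', 'Sue'],
                    'hire_date': [2004, 2008, 2012, 2014]})

# Passes: each employee appears exactly once on each side
df3 = pd.merge(df1, df2, validate='one_to_one')

# Hypothetical bad input: 'Bob' listed twice breaks the one-to-one assumption
extra = pd.DataFrame({'employee': ['Bob'], 'hire_date': [2016]})
df2_dup = pd.concat([df2, extra], ignore_index=True)
try:
    pd.merge(df1, df2_dup, validate='one_to_one')
    failed = False
except ValueError:   # MergeError subclasses ValueError
    failed = True

print(len(df3), failed)  # 4 True
```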
The pd.merge() function recognizes that each DataFrame has an "employee" column, and automatically joins
using this column as a key. The result of the merge is a new DataFrame that combines the information from the two
inputs. Notice that the order of entries in each column is not necessarily maintained: in this case, the order of the
"employee" column differs between df1 and df2 , and the pd.merge() function correctly accounts for this.
Additionally, keep in mind that the merge in general discards the index, except in the special case of merges by index
(see the left_index and right_index keywords, discussed momentarily).
pd.merge() 函数会自动识别每个 DataFrame 都有"employee"列,因此会自动按照这个列作为键对双方进行合并。合并的结果是一个
新的 DataFrame ,其中的数据是两个输入数据集的联合。再注意到每个列的排列顺序在结果中并不一定保持了:在这个情况
下,"employee" 列的顺序在 df1 和 df2 中是不同的,而 pd.merge() 函数也正确的考虑到了这点。而且,要知道的是,合并的结果通
常会丢弃了原本的行索引标签,除非在合并时指定了按行索引合并(参见我们马上会讨论到的 left_index 和 right_index 参数)。
Many-to-one joins
多对⼀
Many-to-one joins are joins in which one of the two key columns contains duplicate entries. For the many-to-one case,
the resulting DataFrame will preserve those duplicate entries as appropriate. Consider the following example of a
many-to-one join:
多对⼀联表的情况发⽣在两个数据集的关键字列上的其中⼀个含有重复数据的时候。在这种多对⼀的情况下,结果的 DataFrame 会正确
的保留那些重复的键值。看下⾯这个例⼦:
In [4]: df4 = pd.DataFrame({'group': ['Accounting', 'Engineering', 'HR'],
                            'supervisor': ['Carly', 'Guido', 'Steve']})
        display('df3', 'df4', 'pd.merge(df3, df4)')
Out[4]:
df3
  employee        group  hire_date
0      Bob   Accounting       2008
1     Jake  Engineering       2012
2     Lisa  Engineering       2004
3      Sue           HR       2014

df4
         group supervisor
0   Accounting      Carly
1  Engineering      Guido
2           HR      Steve

pd.merge(df3, df4)
  employee        group  hire_date supervisor
0      Bob   Accounting       2008      Carly
1     Jake  Engineering       2012      Guido
2     Lisa  Engineering       2004      Guido
3      Sue           HR       2014      Steve
The resulting DataFrame has an additional column with the "supervisor" information, where the information is repeated
in one or more locations as required by the inputs.
结果的 DataFrame 多了⼀列 supervisor ,上⾯的数据也是按照 group 的重复情况进⾏重复的。
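The same validate parameter can document a many-to-one expectation: validate='many_to_one' only checks that the key is unique in the right-hand frame, which is exactly what makes the left rows survive one-for-one. A minimal sketch rebuilding df3 and df4 from the outputs above:

```python
import pandas as pd

df3 = pd.DataFrame({'employee': ['Bob', 'Jake', 'Lisa', 'Sue'],
                    'group': ['Accounting', 'Engineering', 'Engineering', 'HR'],
                    'hire_date': [2008, 2012, 2004, 2014]})
df4 = pd.DataFrame({'group': ['Accounting', 'Engineering', 'HR'],
                    'supervisor': ['Carly', 'Guido', 'Steve']})

# 'group' repeats on the left but is unique on the right: many-to-one
merged = pd.merge(df3, df4, validate='many_to_one')

# Every left row is kept once; each supervisor is repeated per group member
print(len(merged) == len(df3))  # True
print(merged.loc[merged['group'] == 'Engineering', 'supervisor'].tolist())
# ['Guido', 'Guido']
```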
Many-to-many joins
多对多
Many-to-many joins are a bit confusing conceptually, but are nevertheless well defined. If the key column in both the left
and right DataFrame contains duplicates, then the result is a many-to-many merge. This will be perhaps most clear with a
concrete example. Consider the following, where we have a DataFrame showing one or more skills associated with a
particular group. By performing a many-to-many join, we can recover the skills associated with any individual person:
多对多联表在概念上有一点难以理解,但实际上是有良好定义的。如果左右的数据集在关键字列上都有重复数据,那么结果就是一个多对多的组
合。用一个具体的例子来说明会更加清楚。比如下面的数据集 df5 存储的是每个部门和其对应的技能。进行了多对多联表后,我们
可以获得每个员工对应的技能表:
In [5]: df5 = pd.DataFrame({'group': ['Accounting', 'Accounting',
                                      'Engineering', 'Engineering', 'HR', 'HR'],
                            'skills': ['math', 'spreadsheets', 'coding', 'linux',
                                       'spreadsheets', 'organization']})
        display('df1', 'df5', "pd.merge(df1, df5)")
Out[5]:
df1
  employee        group
0      Bob   Accounting
1     Jake  Engineering
2     Lisa  Engineering
3      Sue           HR

df5
         group        skills
0   Accounting          math
1   Accounting  spreadsheets
2  Engineering        coding
3  Engineering         linux
4           HR  spreadsheets
5           HR  organization

pd.merge(df1, df5)
  employee        group        skills
0      Bob   Accounting          math
1      Bob   Accounting  spreadsheets
2     Jake  Engineering        coding
3     Jake  Engineering         linux
4     Lisa  Engineering        coding
5     Lisa  Engineering         linux
6      Sue           HR  spreadsheets
7      Sue           HR  organization
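The size of an inner many-to-many result follows from simple counting: each key contributes (rows with that key on the left) times (rows with that key on the right). A minimal check of that arithmetic on df1 and df5 from above, where Accounting gives 1 times 2, Engineering 2 times 2, and HR 1 times 2, for 8 rows in total:

```python
import pandas as pd

df1 = pd.DataFrame({'employee': ['Bob', 'Jake', 'Lisa', 'Sue'],
                    'group': ['Accounting', 'Engineering', 'Engineering', 'HR']})
df5 = pd.DataFrame({'group': ['Accounting', 'Accounting', 'Engineering',
                              'Engineering', 'HR', 'HR'],
                    'skills': ['math', 'spreadsheets', 'coding', 'linux',
                               'spreadsheets', 'organization']})

merged = pd.merge(df1, df5)

# Rows per key = (count on the left) * (count on the right);
# the Series product aligns on the group labels, dropna removes unmatched keys
left_counts = df1['group'].value_counts()
right_counts = df5['group'].value_counts()
expected = (left_counts * right_counts).dropna().sum()

print(len(merged), int(expected))  # 8 8
```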
These three types of joins can be used with other Pandas tools to implement a wide array of functionality. But in practice,
datasets are rarely as clean as the one we're working with here. In the following section we'll consider some of the
options provided by pd.merge() that enable you to tune how the join operations work.
这三种类型的连接可以和Pandas的其他工具联合使用,来实现很强大的功能。但是在实践中,数据集极少像我们上面的例子那样干净。
在接下来的部分,我们会介绍 pd.merge() 提供的⼀些参数,能让你精细的对连接操作进⾏调整。
Specification of the Merge Key
指定合并关键字
We've already seen the default behavior of pd.merge() : it looks for one or more matching column names between the
two inputs, and uses this as the key. However, often the column names will not match so nicely, and pd.merge()
provides a variety of options for handling this.
上⾯我们看到 pd.merge() 的默认⾏为:它会在两个输⼊数据集中寻找⼀个或多个相同的列名,然后使⽤这(些)列作为合并的关键
字。然⽽,通常情况下,列名并不会这么匹配, pd.merge() 提供了⼀系列的参数来处理这种情况。
The on keyword
on 关键字参数
Most simply, you can explicitly specify the name of the key column using the on keyword, which takes a column name
or a list of column names:
最简单的,你可以使⽤ on 关键字参数明确指定合并使⽤的关键字列名,参数可以是⼀个列名或者⼀个列名的列表:
In [6]: display('df1', 'df2', "pd.merge(df1, df2, on='employee')")
Out[6]:
df1
  employee        group
0      Bob   Accounting
1     Jake  Engineering
2     Lisa  Engineering
3      Sue           HR

df2
  employee  hire_date
0     Lisa       2004
1      Bob       2008
2     Jake       2012
3      Sue       2014

pd.merge(df1, df2, on='employee')
  employee        group  hire_date
0      Bob   Accounting       2008
1     Jake  Engineering       2012
2     Lisa  Engineering       2004
3      Sue           HR       2014
This option works only if both the left and right DataFrame s have the specified column name.
该参数仅在左右两个 DataFrame 都含有相同的指定列名的情况下有效。
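When a row is only identified by several columns together, on takes a list of column names and a row matches only if all of them agree. A minimal sketch with hypothetical store/year data:

```python
import pandas as pd

# Hypothetical data where a row is identified by (store, year) together
sales = pd.DataFrame({'store': ['A', 'A', 'B'], 'year': [2019, 2020, 2020],
                      'revenue': [100, 120, 90]})
staff = pd.DataFrame({'store': ['A', 'B', 'A'], 'year': [2020, 2020, 2019],
                      'headcount': [6, 4, 5]})

# Match on both key columns at once
merged = pd.merge(sales, staff, on=['store', 'year'])
print(merged)
```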
The left_on and right_on keywords
left_on 和 right_on 关键字参数
At times you may wish to merge two datasets with different column names; for example, we may have a dataset in which
the employee name is labeled as "name" rather than "employee". In this case, we can use the left_on and
right_on keywords to specify the two column names:
在你希望使⽤不同列名来合并两个数据集的情况下;例如,我们有⼀个数据集,在它⾥⾯员⼯姓名的列名不是"employee"⽽是"name"。在
这种情况下,我们可以使⽤ left_on 和 right_on 关键字来分别指定两个列的名字:
In [7]: df3 = pd.DataFrame({'name': ['Bob', 'Jake', 'Lisa', 'Sue'],
                            'salary': [70000, 80000, 120000, 90000]})
        display('df1', 'df3', 'pd.merge(df1, df3, left_on="employee", right_on="name")')
Out[7]:
df1
  employee        group
0      Bob   Accounting
1     Jake  Engineering
2     Lisa  Engineering
3      Sue           HR

df3
   name  salary
0   Bob   70000
1  Jake   80000
2  Lisa  120000
3   Sue   90000

pd.merge(df1, df3, left_on="employee", right_on="name")
  employee        group  name  salary
0      Bob   Accounting   Bob   70000
1     Jake  Engineering  Jake   80000
2     Lisa  Engineering  Lisa  120000
3      Sue           HR   Sue   90000
The result has a redundant column that we can drop if desired, for example, by using the drop() method of
DataFrame s:
结果中有一个冗余的列,我们可以将该列移除,例如使用 DataFrame 的 drop() 方法:
In [8]: pd.merge(df1, df3, left_on="employee", right_on="name").drop('name', axis=1)
Out[8]:
  employee        group  salary
0      Bob   Accounting   70000
1     Jake  Engineering   80000
2     Lisa  Engineering  120000
3      Sue           HR   90000
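An alternative to dropping the duplicate key afterwards, offered here as one option rather than the book's approach, is to rename the right-hand key column first so that a plain on= merge applies and no redundant column appears. A minimal sketch with df1 and df3 rebuilt from above:

```python
import pandas as pd

df1 = pd.DataFrame({'employee': ['Bob', 'Jake', 'Lisa', 'Sue'],
                    'group': ['Accounting', 'Engineering', 'Engineering', 'HR']})
df3 = pd.DataFrame({'name': ['Bob', 'Jake', 'Lisa', 'Sue'],
                    'salary': [70000, 80000, 120000, 90000]})

# Align the key names up front; the merge then yields a single key column
merged = pd.merge(df1, df3.rename(columns={'name': 'employee'}), on='employee')
print(merged.columns.tolist())  # ['employee', 'group', 'salary']
```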
The left_index and right_index keywords
left_index 和 right_index 关键字参数
Sometimes, rather than merging on a column, you would instead like to merge on an index. For example, your data might
look like this:
有时候,你不是需要按列进⾏合并,⽽是需要按照⾏索引进⾏合并。例如,将 df1 和 df2 数据集修改为如下情况:
In [9]: df1a = df1.set_index('employee')
        df2a = df2.set_index('employee')
        display('df1a', 'df2a')
Out[9]:
df1a
                group
employee
Bob        Accounting
Jake      Engineering
Lisa      Engineering
Sue                HR

df2a
          hire_date
employee
Lisa           2004
Bob            2008
Jake           2012
Sue            2014
You can use the index as the key for merging by specifying the left_index and/or right_index flags in
pd.merge() :
通过指定 left_index 和 right_index 标志参数,你可以将两个数据集按照⾏索引进⾏合并:
In [10]: display('df1a', 'df2a',
                 "pd.merge(df1a, df2a, left_index=True, right_index=True)")
Out[10]:
df1a
                group
employee
Bob        Accounting
Jake      Engineering
Lisa      Engineering
Sue                HR

df2a
          hire_date
employee
Lisa           2004
Bob            2008
Jake           2012
Sue            2014

pd.merge(df1a, df2a, left_index=True, right_index=True)
                group  hire_date
employee
Bob        Accounting       2008
Jake      Engineering       2012
Lisa      Engineering       2004
Sue                HR       2014
For convenience, DataFrame s implement the join() method, which performs a merge that defaults to joining on
indices:
为了⽅便, DataFrame 实现了 join() ⽅法,默认按照⾏索引合并数据集:
In [11]: display('df1a', 'df2a', 'df1a.join(df2a)')
Out[11]:
df1a
                group
employee
Bob        Accounting
Jake      Engineering
Lisa      Engineering
Sue                HR

df2a
          hire_date
employee
Lisa           2004
Bob            2008
Jake           2012
Sue            2014

df1a.join(df2a)
                group  hire_date
employee
Bob        Accounting       2008
Jake      Engineering       2012
Lisa      Engineering       2004
Sue                HR       2014
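Unlike pd.merge, whose default is an inner join, DataFrame.join defaults to how='left', keeping every row of the calling frame. The difference only shows when the indexes do not match completely, so the sketch below uses a hypothetical df2b that lacks Sue's hire date:

```python
import pandas as pd

df1a = pd.DataFrame({'group': ['Accounting', 'Engineering', 'Engineering', 'HR']},
                    index=pd.Index(['Bob', 'Jake', 'Lisa', 'Sue'], name='employee'))
# Hypothetical: hire dates known for only three employees
df2b = pd.DataFrame({'hire_date': [2004, 2008, 2012]},
                    index=pd.Index(['Lisa', 'Bob', 'Jake'], name='employee'))

left = df1a.join(df2b)                 # default how='left': Sue kept, NaN date
inner = df1a.join(df2b, how='inner')   # only employees present in both
print(left)
print(inner)
```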
If you'd like to mix indices and columns, you can combine left_index with right_on or left_on with
right_index to get the desired behavior:
如果需要混合的进⾏⾏或列的合并,你可以通过混合指定 left_index 和 right_on 参数或者 left_on 和 right_index 参数来实
现:
In [12]: display('df1a', 'df3', "pd.merge(df1a, df3, left_index=True, right_on='name')")
Out[12]:
df1a
                group
employee
Bob        Accounting
Jake      Engineering
Lisa      Engineering
Sue                HR

df3
   name  salary
0   Bob   70000
1  Jake   80000
2  Lisa  120000
3   Sue   90000

pd.merge(df1a, df3, left_index=True, right_on='name')
         group  name  salary
0   Accounting   Bob   70000
1  Engineering  Jake   80000
2  Engineering  Lisa  120000
3           HR   Sue   90000
All of these options also work with multiple indices and/or multiple columns; the interface for this behavior is very intuitive.
For more information on this, see the "Merge, Join, and Concatenate" section of the Pandas documentation.
所有上⾯的参数都能应⽤到多重⾏索引和/或多重列上;这个接⼝的定义是⾮常直观的。需要了解更多的信息,参⻅Pandas在线⽂
档"Merge、join和Concatenate"章节。
Specifying Set Arithmetic for Joins
指定合并的集合算术运算
In all the preceding examples we have glossed over one important consideration in performing a join: the type of set
arithmetic used in the join. This comes up when a value appears in one key column but not the other. Consider this
example:
在上⾯的例⼦中,我们都忽略了在进⾏数据集合并时⼀个重要的内容:合并时所使⽤的集合算术运算类型。这部分内容对于当⼀个数据集
的键值在另⼀个数据集中不存在时很有意义。看下例:
In [13]: df6 = pd.DataFrame({'name': ['Peter', 'Paul', 'Mary'],
                             'food': ['fish', 'beans', 'bread']},
                            columns=['name', 'food'])
         df7 = pd.DataFrame({'name': ['Mary', 'Joseph'],
                             'drink': ['wine', 'beer']},
                            columns=['name', 'drink'])
         display('df6', 'df7', 'pd.merge(df6, df7)')
Out[13]:
df6
    name   food
0  Peter   fish
1   Paul  beans
2   Mary  bread

df7
     name drink
0    Mary  wine
1  Joseph  beer

pd.merge(df6, df7)
   name   food drink
0  Mary  bread  wine
Here we have merged two datasets that have only a single "name" entry in common: Mary. By default, the result contains
the intersection of the two sets of inputs; this is what is known as an inner join. We can specify this explicitly using the
how keyword, which defaults to "inner" :
上面我们合并的两个数据集在关键字列上只有一个"name"数据是共同的:Mary。默认情况下,结果会包含两个集合的交集;这被称为内连
接。我们可以使用 how 关键字参数显式地指定连接方式,它的默认值是 "inner" :
In [14]: pd.merge(df6, df7, how='inner')
Out[14]:
   name   food drink
0  Mary  bread  wine
Other options for the how keyword are 'outer' , 'left' , and 'right' . An outer join returns a join over the
union of the input columns, and fills in all missing values with NAs:
how 参数的其他选项包括 'outer' 、 'left' 和 'right' 。外连接outer会返回两个集合的并集,并将缺失的数据填充为Pandas的
NA 值:
In [15]: display('df6', 'df7', "pd.merge(df6, df7, how='outer')")
Out[15]:
df6
    name   food
0  Peter   fish
1   Paul  beans
2   Mary  bread

df7
     name drink
0    Mary  wine
1  Joseph  beer

pd.merge(df6, df7, how='outer')
     name   food drink
0   Peter   fish   NaN
1    Paul  beans   NaN
2    Mary  bread  wine
3  Joseph    NaN  beer
The left join and right join return joins over the left entries and right entries, respectively. For example:
左连接left和右连接right返回的结果是包括所有的左边或右边集合。例如:
In [16]: display('df6', 'df7', "pd.merge(df6, df7, how='left')")
Out[16]:
df6
    name   food
0  Peter   fish
1   Paul  beans
2   Mary  bread

df7
     name drink
0    Mary  wine
1  Joseph  beer

pd.merge(df6, df7, how='left')
    name   food drink
0  Peter   fish   NaN
1   Paul  beans   NaN
2   Mary  bread  wine
The output rows now correspond to the entries in the left input. Using how='right' works in a similar manner.
All of these options can be applied straightforwardly to any of the preceding join types.
结果中的⾏与左集合保持⼀致。使⽤ how='right' 结果会和右集合保持⼀致。
所有这些集合运算类型可以和前⾯的连接类型组合使⽤。
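To see which side each row of an outer join came from, pd.merge accepts indicator=True, which adds a categorical _merge column with the values 'left_only', 'right_only', or 'both'. A minimal sketch with df6 and df7 from above, also showing that how='right' keeps the right-hand keys in their original order:

```python
import pandas as pd

df6 = pd.DataFrame({'name': ['Peter', 'Paul', 'Mary'],
                    'food': ['fish', 'beans', 'bread']})
df7 = pd.DataFrame({'name': ['Mary', 'Joseph'],
                    'drink': ['wine', 'beer']})

# indicator=True records the origin of every row in a '_merge' column
merged = pd.merge(df6, df7, how='outer', indicator=True)
print(merged[['name', '_merge']])

# how='right' keeps exactly the keys of the right-hand frame
right = pd.merge(df6, df7, how='right')
print(right['name'].tolist())  # ['Mary', 'Joseph']
```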
Overlapping Column Names: The suffixes Keyword
列名冲突: suffixes 关键字参数
Finally, you may end up in a case where your two input DataFrame s have conflicting column names. Consider this
example:
最后,你可能会碰到两个输入 DataFrame 有着冲突列名的情况。例如:
In [17]: df8 = pd.DataFrame({'name': ['Bob', 'Jake', 'Lisa', 'Sue'],
                             'rank': [1, 2, 3, 4]})
         df9 = pd.DataFrame({'name': ['Bob', 'Jake', 'Lisa', 'Sue'],
                             'rank': [3, 1, 4, 2]})
         display('df8', 'df9', 'pd.merge(df8, df9, on="name")')
Out[17]:
df8
   name  rank
0   Bob     1
1  Jake     2
2  Lisa     3
3   Sue     4

df9
   name  rank
0   Bob     3
1  Jake     1
2  Lisa     4
3   Sue     2

pd.merge(df8, df9, on="name")
   name  rank_x  rank_y
0   Bob       1       3
1  Jake       2       1
2  Lisa       3       4
3   Sue       4       2
Because the output would have two conflicting column names, the merge function automatically appends a suffix _x or
_y to make the output columns unique. If these defaults are inappropriate, it is possible to specify a custom suffix using
the suffixes keyword:
因为结果可能会有两个相同的列名,发⽣冲突,merge函数会⾃动为这两个列添加 _x 和 _y 后缀,使得输出结果每个列名称唯⼀。如果
默认的后缀不是你希望的,可以使⽤ suffixes 关键字参数为输出列添加⾃定义的后缀:
In [18]: display('df8', 'df9', 'pd.merge(df8, df9, on="name", suffixes=["_L", "_R"])')
Out[18]:
df8
   name  rank
0   Bob     1
1  Jake     2
2  Lisa     3
3   Sue     4

df9
   name  rank
0   Bob     3
1  Jake     1
2  Lisa     4
3   Sue     2

pd.merge(df8, df9, on="name", suffixes=["_L", "_R"])
   name  rank_L  rank_R
0   Bob       1       3
1  Jake       2       1
2  Lisa       3       4
3   Sue       4       2
These suffixes work in any of the possible join patterns, and work also if there are multiple overlapping columns.
这些后缀可以应⽤在所有的连接⽅式中,也可以在多个列冲突时使⽤。
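When more than one non-key column overlaps, the same suffix pair is applied to each of them: left-hand columns come first with the first suffix, followed by the right-hand ones with the second. A minimal sketch with a hypothetical score column added to both inputs:

```python
import pandas as pd

# Hypothetical: two rankings that share both 'rank' and 'score' columns
df8 = pd.DataFrame({'name': ['Bob', 'Jake'], 'rank': [1, 2], 'score': [90, 80]})
df9 = pd.DataFrame({'name': ['Bob', 'Jake'], 'rank': [2, 1], 'score': [70, 95]})

# Every overlapping non-key column receives the matching suffix
merged = pd.merge(df8, df9, on='name', suffixes=('_L', '_R'))
print(merged.columns.tolist())
# ['name', 'rank_L', 'score_L', 'rank_R', 'score_R']
```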
For more information on these patterns, see Aggregation and Grouping where we dive a bit deeper into relational
algebra. Also see the Pandas "Merge, Join and Concatenate" documentation for further discussion of these topics.
需要了解更多知识,参⻅聚合与分组,我们会更加深⼊的介绍关系代数。也可以参⻅Pandas在线"Merge, Join 和 Concatenate"⽂档学习更
多内容。
Example: US States Data
例⼦:美国州数据
Merge and join operations come up most often when combining data from different sources. Here we will consider an
example of some data about US states and their populations. The data files can be found at
http://github.com/jakevdp/data-USstates/:
合并及联表操作在你处理多个不同数据来源时会经常出现。下⾯我们使⽤美国州及其⼈⼝数据作为例⼦来进⾏更加直观的说明。这些数据
⽂件可以在http://github.com/jakevdp/data-USstates/ 中找到:
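If you don't have the data files yet, you can download them with the commands below:
如果你没有数据文件,可以使用下面的命令下载它们:
In [19]: # Uncomment these lines to download the data files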
# !curl -O https://raw.githubusercontent.com/jakevdp/data-USstates/master/state-population.csv
# !curl -O https://raw.githubusercontent.com/jakevdp/data-USstates/master/state-areas.csv
# !curl -O https://raw.githubusercontent.com/jakevdp/data-USstates/master/state-abbrevs.csv
Let's take a look at the three datasets, using the Pandas read_csv() function:
下⾯我们来载⼊三个相关的数据⽂件,使⽤Pandas的 read_csv() 函数:
In [20]: pop = pd.read_csv('data/state-population.csv')
areas = pd.read_csv('data/state-areas.csv')
abbrevs = pd.read_csv('data/state-abbrevs.csv')
display('pop.head()', 'areas.head()', 'abbrevs.head()')
Out[20]:
pop.head()
  state/region     ages  year  population
0           AL  under18  2012   1117489.0
1           AL    total  2012   4817528.0
2           AL  under18  2010   1130966.0
3           AL    total  2010   4785570.0
4           AL  under18  2011   1125763.0

areas.head()
        state  area (sq. mi)
0     Alabama          52423
1      Alaska         656425
2     Arizona         114006
3    Arkansas          53182
4  California         163707

abbrevs.head()
        state abbreviation
0     Alabama           AL
1      Alaska           AK
2     Arizona           AZ
3    Arkansas           AR
4  California           CA
Given this information, say we want to compute a relatively straightforward result: rank US states and territories by their
2010 population density. We clearly have the data here to find this result, but we'll have to combine the datasets to find
the result.
有了数据之后,假如我们需要计算⼀个相对⾮常直接的结果:根据美国各州2010年⼈⼝密度进⾏排名。很显然我们有相关的数据,但是我
们需要合并数据集才能找到结果。
We'll start with a many-to-one merge that will give us the full state name within the population DataFrame . We want to
merge based on the state/region column of pop , and the abbreviation column of abbrevs . We'll use
how='outer' to make sure no data is thrown away due to mismatched labels.
我们先进行一个多对一的合并,将州全名和人口数据合并在一个 DataFrame 中。我们希望合并基于 pop 数据集的 state/region 列以
及 abbrevs 数据集的 abbreviation 列。使用 how='outer' 来保证合并过程中不会因为不匹配的标签而丢失任何数据。
In [21]: merged = pd.merge(pop, abbrevs, how='outer',
                  left_on='state/region', right_on='abbreviation')
merged = merged.drop('abbreviation', axis=1)  # drop duplicate info (移除冗余的列)
merged.head()
Out[21]:
  state/region     ages  year  population    state
0           AL  under18  2012   1117489.0  Alabama
1           AL    total  2012   4817528.0  Alabama
2           AL  under18  2010   1130966.0  Alabama
3           AL    total  2010   4785570.0  Alabama
4           AL  under18  2011   1125763.0  Alabama
Let's double-check whether there were any mismatches here, which we can do by looking for rows with nulls:
让我们检查结果中是否有不匹配的情况,通过在数据集中寻找空值来查看:
In [22]: merged.isnull().any()
Out[22]: state/region    False
         ages            False
         year            False
         population       True
         state            True
         dtype: bool
Some of the population info is null; let's figure out which these are!
⼀些⼈⼝ population 数据是空的;再来看看是哪些。
In [23]: merged[merged['population'].isnull()].head()
Out[23]:
     state/region     ages  year  population state
2448           PR  under18  1990         NaN   NaN
2449           PR    total  1990         NaN   NaN
2450           PR    total  1991         NaN   NaN
2451           PR  under18  1991         NaN   NaN
2452           PR    total  1993         NaN   NaN
It appears that all the null population values are from Puerto Rico prior to the year 2000; this is likely due to this data not
being available from the original source.
发现所有空的⼈⼝数据都是2000年前波多黎各的;这可能因为数据来源本来就没有这些数据造成的。
More importantly, we see also that some of the new state entries are also null, which means that there was no
corresponding entry in the abbrevs key! Let's figure out which regions lack this match:
更重要的是,我们发现一些新的州 state 的数据也是空的,这意味着 abbrevs 数据集中不存在对应的键。再看看是哪些地区缺少匹配:
In [24]: merged.loc[merged['state'].isnull(), 'state/region'].unique()
Out[24]: array(['PR', 'USA'], dtype=object)
We can quickly infer the issue: our population data includes entries for Puerto Rico (PR) and the United States as a
whole (USA), while these entries do not appear in the state abbreviation key. We can fix these quickly by filling in
appropriate entries:
从上⾯的结果很容易发现:⼈⼝数据集中包括波多黎各(PR)和全美国(USA)的数据,⽽州简称数据集中却没有这两者数据。通过填充
相应的数据可以很快解决这个问题:
In [25]: merged.loc[merged['state/region'] == 'PR', 'state'] = 'Puerto Rico'
merged.loc[merged['state/region'] == 'USA', 'state'] = 'United States'
merged.isnull().any()
Out[25]: state/region    False
         ages            False
         year            False
         population       True
         state           False
         dtype: bool
No more nulls in the state column: we're all set!
state 列没有空值了:我们准备好了。
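As an aside, `pd.merge` also accepts `indicator=True`, which adds a `_merge` column recording whether each row matched on both sides. This can find mismatches like PR and USA directly, without relying on nulls. The sketch below uses tiny hypothetical versions of `pop` and `abbrevs`. The population numbers are placeholders, not the real data:

```python
import pandas as pd

# Hypothetical miniature tables; population values are placeholders only
pop = pd.DataFrame({'state/region': ['AL', 'AK', 'PR', 'USA'],
                    'population': [100.0, 200.0, 300.0, 400.0]})
abbrevs = pd.DataFrame({'state': ['Alabama', 'Alaska'],
                        'abbreviation': ['AL', 'AK']})

# indicator=True adds '_merge' with values 'both', 'left_only' or 'right_only'
check = pd.merge(pop, abbrevs, how='outer',
                 left_on='state/region', right_on='abbreviation',
                 indicator=True)

# Keys present in pop but missing from abbrevs
unmatched = sorted(check.loc[check['_merge'] == 'left_only', 'state/region'])
print(unmatched)  # ['PR', 'USA']
```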
Now we can merge the result with the area data using a similar procedure. Examining our results, we will want to join on
the state column in both:
下⾯我们可以将上⾯的结果数据集和⾯积数据集进⾏合并。研究两个数据集发现,我们需要在 state 列上进⾏数据集合并操作:
In [26]: final = pd.merge(merged, areas, on='state', how='left')
final.head()
Out[26]:
  state/region     ages  year  population    state  area (sq. mi)
0           AL  under18  2012   1117489.0  Alabama        52423.0
1           AL    total  2012   4817528.0  Alabama        52423.0
2           AL  under18  2010   1130966.0  Alabama        52423.0
3           AL    total  2010   4785570.0  Alabama        52423.0
4           AL  under18  2011   1125763.0  Alabama        52423.0
Again, let's check for nulls to see if there were any mismatches:
再⼀次,我们检查⼀次空值,来看是否存在不匹配的情况:
In [27]: final.isnull().any()
Out[27]: state/region     False
         ages             False
         year             False
         population        True
         state            False
         area (sq. mi)     True
         dtype: bool
There are nulls in the area column; we can take a look to see which regions were ignored here:
⾯积 area 列有空值;我们看看是哪⾥出现的:
In [28]: final['state'][final['area (sq. mi)'].isnull()].unique()
Out[28]: array(['United States'], dtype=object)
We see that our areas DataFrame does not contain the area of the United States as a whole. We could insert the
appropriate value (using the sum of all state areas, for instance), but in this case we'll just drop the null values because
the population density of the entire United States is not relevant to our current discussion:
结果显⽰⾯积数据集不包括整个美国的⾯积。我们可以为这个空值插⼊正确的值(使⽤所有州的⾯积数据之和),但是这个例⼦中我们只
需要简单地移除空值数据即可,因为全美国的⼈⼝密度数据与我们前⾯的问题⽆关:
In [29]: final.dropna(inplace=True)
final.head()
Out[29]:
  state/region     ages  year  population    state  area (sq. mi)
0           AL  under18  2012   1117489.0  Alabama        52423.0
1           AL    total  2012   4817528.0  Alabama        52423.0
2           AL  under18  2010   1130966.0  Alabama        52423.0
3           AL    total  2010   4785570.0  Alabama        52423.0
4           AL  under18  2011   1125763.0  Alabama        52423.0
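For completeness, here is a sketch of the alternative mentioned above: filling in the United States area with the sum of the individual areas instead of dropping the rows. It runs on a hypothetical miniature `final` table with placeholder areas. One detail matters here. Each state appears once per year and age group, so the rows have to be de-duplicated by state before summing, or the total is inflated:

```python
import numpy as np
import pandas as pd

# Hypothetical miniature 'final' table; areas are placeholder numbers
final = pd.DataFrame({'state': ['Alabama', 'Alabama', 'Alaska', 'Alaska',
                                'United States', 'United States'],
                      'year': [2010, 2011, 2010, 2011, 2010, 2011],
                      'area (sq. mi)': [50.0, 50.0, 650.0, 650.0, np.nan, np.nan]})

is_us = final['state'] == 'United States'

# One area per state: de-duplicate first, otherwise every year is counted again
state_areas = final.loc[~is_us].drop_duplicates('state')['area (sq. mi)']
us_area = state_areas.sum()

# Fill only the missing United States rows
final.loc[is_us, 'area (sq. mi)'] = us_area
print(us_area)  # 700.0
```

With the real data you would also need to decide whether entries such as Puerto Rico and the District of Columbia belong in that total.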
Now we have all the data we need. To answer the question of interest, let's first select the portion of the data
corresponding with the year 2010, and the total population. We'll use the query() function to do this quickly (pandas
evaluates the expression with the numexpr package when it is installed and falls back to plain Python otherwise; see
High-Performance Pandas: eval() and query() ):
现在我们需要的数据都已经准备好了。要回答前面那个问题,首先要选择出2010年相应的部分数据集以及不分年龄的全体人口数。我们使用
query() 函数来快速完成这项任务(如果安装了 numexpr 包,pandas 会用它来计算表达式,否则退回到普通的 Python 计算;参见高性能
Pandas: eval() 和 query() ):
In [30]: data2010 = final.query("year == 2010 & ages == 'total'")
data2010.head()
Out[30]:
    state/region   ages  year  population       state  area (sq. mi)
3             AL  total  2010   4785570.0     Alabama        52423.0
91            AK  total  2010    713868.0      Alaska       656425.0
101           AZ  total  2010   6408790.0     Arizona       114006.0
189           AR  total  2010   2922280.0    Arkansas        53182.0
197           CA  total  2010  37333601.0  California       163707.0
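The query() string above can also be written as an ordinary boolean mask, which is a useful way to check what the expression does. Here is a small sketch on a hypothetical stand-in table. Note that in the mask form each comparison needs its own parentheses, because `&` binds more tightly than `==` in plain Python:

```python
import pandas as pd

# Tiny hypothetical stand-in for the merged table
final = pd.DataFrame({'state': ['Alabama', 'Alabama', 'Alaska', 'Alaska'],
                      'ages': ['total', 'under18', 'total', 'total'],
                      'year': [2010, 2010, 2010, 2011]})

# String expression, as used in In [30]
via_query = final.query("year == 2010 & ages == 'total'")

# Equivalent boolean mask (parentheses required around each comparison)
via_mask = final[(final['year'] == 2010) & (final['ages'] == 'total')]

print(via_query.index.tolist())  # [0, 2]
```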
Now let's compute the population density and display it in order. We'll start by re-indexing our data on the state, and then
compute the result:
下面我们可以计算人口密度并排序输出了。我们先将数据集按照 state 进行重新索引,然后计算结果:
In [31]: data2010 = data2010.set_index('state')  # avoids modifying a filtered result in place
density = data2010['population'] / data2010['area (sq. mi)']
In [32]: density = density.sort_values(ascending=False)
density.head()
Out[32]: state
         District of Columbia    8898.897059
         Puerto Rico             1058.665149
         New Jersey              1009.253268
         Rhode Island             681.339159
         Connecticut              645.600649
         dtype: float64
The result is a ranking of US states plus Washington, DC, and Puerto Rico in order of their 2010 population density, in
residents per square mile. We can see that by far the densest region in this dataset is Washington, DC (i.e., the District of
Columbia); among states, the densest is New Jersey.
结果是美国州根据2010年⼈⼝密度的排名,包括华盛顿特区和波多黎各,数据是每平⽅英⾥的居住⼈数。结果显⽰⼈⼝密度最稠密的地区
是华盛顿特区(表中的the District of Columbia);在其他的州中,⼈⼝密度最⼤的是新泽西。
We can also check the end of the list:
我们也可以查看结果的最后部分:
In [33]: density.tail()
Out[33]: state
         South Dakota    10.583512
         North Dakota     9.537565
         Montana          6.736171
         Wyoming          5.768079
         Alaska           1.087509
         dtype: float64
We see that the least dense state, by far, is Alaska, averaging slightly over one resident per square mile.
结果显示,人口密度最小的州是阿拉斯加,远低于其他州,平均每平方英里仅略多于1个居民。
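Instead of sorting the whole Series and taking head() or tail(), `Series.nlargest` and `Series.nsmallest` return the top or bottom entries directly. pandas documents `nlargest(n)` as equivalent to `sort_values(ascending=False).head(n)`. The sketch below reuses the five 2010 rows shown in Out[30]. Restricting to these five rows is a simplification for illustration:

```python
import pandas as pd

# The five 2010 rows shown in Out[30] (population, area in sq. mi)
data2010 = pd.DataFrame(
    {'population': [4785570.0, 713868.0, 6408790.0, 2922280.0, 37333601.0],
     'area (sq. mi)': [52423.0, 656425.0, 114006.0, 53182.0, 163707.0]},
    index=pd.Index(['Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California'],
                   name='state'))

density = data2010['population'] / data2010['area (sq. mi)']

# Densest and least dense entries without an explicit sort
print(density.nlargest(2).index.tolist())   # ['California', 'Alabama']
print(density.nsmallest(1).index.tolist())  # ['Alaska']
```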
This type of messy data merging is a common task when trying to answer questions using real-world data sources. I
hope that this example has given you an idea of the ways you can combine tools we've covered in order to gain insight
from your data!
当使用真实世界的数据来回答这类问题时,这种混乱的数据集合并是很常见的任务。作者希望这个例子能让你了解如何组合使用前面介绍的工
具,从你的数据中获得洞见。
Aggregation and Grouping
聚合与分组
An essential piece of analysis of large data is efficient summarization: computing aggregations like sum() , mean() ,
median() , min() , and max() , in which a single number gives insight into the nature of a potentially large dataset.
In this section, we'll explore aggregations in Pandas, from simple operations akin to what we've seen on NumPy arrays,
to more sophisticated operations based on the concept of a groupby .
对于⼀个⼤数据集进⾏分析的关键部分是使⽤有效的概括:对数据集进⾏ sum() 、 mean() 、 median() 、 min() 和 max() 聚合运
算,这些运算的结果就可能可以给出⼤数据集的⼀些内在特征。在本节中,我们会探讨Pandas中的聚合,从我们已经在NumPy数组中进
⾏过的那些简单的操作,直到基于分组 groupby 概念进⾏的更复杂的操作。
For convenience, we'll use the same display class that we've seen in previous sections:
⽅便起⻅,我们还是使⽤与前两节同样的 display 类来展⽰多个数据集:
In [1]: import numpy as np
import pandas as pd

class display(object):
    """Display HTML representation of multiple objects"""
    # each object is rendered in its own floating <div>, captioned by its expression
    template = """<div style="float: left; padding: 10px;">
    <p style='font-family:"Courier New", Courier, monospace'>{0}</p>{1}
    </div>"""

    def __init__(self, *args):
        # args are expression strings such as 'df'; they are evaluated on display
        self.args = args

    def _repr_html_(self):
        # rich HTML display used by Jupyter
        return '\n'.join(self.template.format(a, eval(a)._repr_html_())
                         for a in self.args)

    def __repr__(self):
        # plain-text fallback
        return '\n\n'.join(a + '\n' + repr(eval(a))
                           for a in self.args)
Planets Data
⾏星数据
Here we will use the Planets dataset, available via the Seaborn package (see Visualization With Seaborn). It gives
information on planets that astronomers have discovered around other stars (known as extrasolar planets or exoplanets
for short). It can be downloaded with a simple Seaborn command:
这⾥我们会使⽤Seaborn包提供的⾏星数据(参⻅使⽤Seaborn进⾏可视化)。这个数据集提供了天⽂学家发现的其他恒星的⾏星的数据
(被称为太阳系外⾏星)。数据集可以简单的使⽤⼀个Seaborn命令来下载:
In [2]: import seaborn as sns
planets = sns.load_dataset('planets')
planets.shape
Out[2]: (1035, 6)
In [3]: planets.head()
Out[3]:
            method  number  orbital_period   mass  distance  year
0  Radial Velocity       1         269.300   7.10     77.40  2006
1  Radial Velocity       1         874.774   2.21     56.95  2008
2  Radial Velocity       1         763.000   2.60     19.84  2011
3  Radial Velocity       1         326.030  19.40    110.62  2007
4  Radial Velocity       1         516.220  10.50    119.47  2009
This has some details on the 1,000+ extrasolar planets discovered up to 2014.
直到2014年已经有超过1000个太阳系外⾏星的数据。
Simple Aggregation in Pandas
在Pandas中进⾏简单聚合
Earlier, we explored some of the data aggregations available for NumPy arrays ("Aggregations: Min, Max, and Everything
In Between"). As with a one-dimensional NumPy array, for a Pandas Series the aggregates return a single value:
上⼀章中,我们已经介绍了NumPy数组的数据聚合操作(聚合:Min, Max, 以及其他)。正如⼀维NumPy数组,Pandas的 Series 的聚
合结果是⼀个标量:
In [4]: rng = np.random.RandomState(42)
ser = pd.Series(rng.rand(5))
ser
Out[4]: 0    0.374540
        1    0.950714
        2    0.731994
        3    0.598658
        4    0.156019
        dtype: float64
In [5]: ser.sum()
Out[5]: 2.811925491708157
In [6]: ser.mean()
Out[6]: 0.5623850983416314
For a DataFrame , by default the aggregates return results within each column:
对于 DataFrame 来说,默认情况下是每个列进⾏聚合的结果:
In [7]: df = pd.DataFrame({'A': rng.rand(5),
'B': rng.rand(5)})
df
Out[7]:
          A         B
0  0.155995  0.020584
1  0.058084  0.969910
2  0.866176  0.832443
3  0.601115  0.212339
4  0.708073  0.181825
In [8]: df.mean()
Out[8]: A    0.477888
        B    0.443420
        dtype: float64
By specifying the axis argument, you can instead aggregate within each row:
通过指定 axis 参数,可以为每⼀⾏进⾏聚合操作:
In [9]: df.mean(axis='columns')
Out[9]: 0    0.088290
        1    0.513997
        2    0.849309
        3    0.406727
        4    0.444949
        dtype: float64
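The axis argument mirrors NumPy's reductions. `axis='index'` (the default, same as `axis=0`) collapses the rows and yields one value per column. `axis='columns'` (`axis=1`) collapses the columns and yields one value per row. Here is a quick sketch that checks this against the underlying NumPy array. The frame is freshly generated, so its values differ from Out[7]:

```python
import numpy as np
import pandas as pd

rng = np.random.RandomState(42)
df = pd.DataFrame({'A': rng.rand(5), 'B': rng.rand(5)})

col_means = df.mean(axis='index')    # one value per column
row_means = df.mean(axis='columns')  # one value per row

# Same numbers as NumPy reductions along axis 0 and axis 1
print(np.allclose(col_means.values, df.values.mean(axis=0)))  # True
print(np.allclose(row_means.values, df.values.mean(axis=1)))  # True
print(col_means.shape, row_means.shape)  # (2,) (5,)
```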
Pandas Series and DataFrame s include all of the common aggregates mentioned in Aggregations: Min, Max, and
Everything In Between; in addition, there is a convenience method describe() that computes several common
aggregates for each column and returns the result. Let's use this on the Planets data, for now dropping rows with missing
values:
Pandas 的 Series 和 DataFrame 包括了所有我们在聚合:Min, Max, 以及其他中介绍过的通用聚合操作;而且Pandas还提供了很方便
的 describe() 可以用来对每个列计算这些通用的聚合结果。让我们在行星数据集上使用这个函数,暂时先移除含有空值的行:
In [10]: planets.dropna().describe()
Out[10]:
          number  orbital_period        mass    distance         year
count  498.00000      498.000000  498.000000  498.000000   498.000000
mean     1.73494      835.778671    2.509320   52.068213  2007.377510
std      1.17572     1469.128259    3.636274   46.596041     4.167284
min      1.00000        1.328300    0.003600    1.350000  1989.000000
25%      1.00000       38.272250    0.212500   24.497500  2005.000000
50%      1.00000      357.000000    1.245000   39.940000  2009.000000
75%      2.00000      999.600000    2.867500   59.332500  2011.000000
max      6.00000    17337.500000   25.000000  354.000000  2014.000000
This can be a useful way to begin understanding the overall properties of a dataset. For example, we see in the year
column that although exoplanets were discovered as far back as 1989, half of the exoplanets in this subset were not
discovered until 2009 or later (the median year is 2009). This is largely thanks to the Kepler mission, which is a
space-based telescope specifically designed for finding eclipsing (transiting) planets around other stars.
对于开始理解数据集的整体情况来说,这是一个非常有用的方法。例如,在发现年份 year 列上,结果显示,虽然第一颗太阳系外行星是
1989年发现的,但是这部分数据中有一半的行星直到2009年或之后才被发现(中位数年份为2009年)。这很大程度上多亏了开普勒Kepler计
划,它是一个太空望远镜,专门设计用来寻找其他恒星周围发生凌星(掩食)的行星。
The following table summarizes some other built-in Pandas aggregations:
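下表概括了Pandas内建的聚合操作:
Aggregation          Description
count()              Total number of items
first(), last()      First and last item
mean(), median()     Mean and median
min(), max()         Minimum and maximum
std(), var()         Standard deviation and variance
mad()                Mean absolute deviation (deprecated in pandas 1.5, removed in 2.0)
prod()               Product of all items
sum()                Sum of all items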
These are all methods of DataFrame and Series objects.
它们都是 DataFrame 和 Series 对象的⽅法。
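Because `mad()` is no longer available in recent pandas versions, the mean absolute deviation can be computed directly from its definition, mean(|x - mean(x)|). A minimal sketch on a small hand-made Series:

```python
import pandas as pd

s = pd.Series([1.0, 2.0, 3.0, 4.0, 10.0])

# Mean absolute deviation around the mean: mean(|x - mean(x)|)
# mean is 4.0, deviations are 3, 2, 1, 0, 6 -> 12 / 5
mad = (s - s.mean()).abs().mean()
print(mad)  # 2.4
```

The same expression works inside a GroupBy, for example `df.groupby('key')['x'].agg(lambda g: (g - g.mean()).abs().mean())`, where `'key'` and `'x'` stand for your own column names.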
To go deeper into the data, however, simple aggregates are often not enough. The next level of data summarization is the
groupby operation, which allows you to quickly and efficiently compute aggregates on subsets of data.
然⽽要深⼊了解数据,简单的聚合经常是不够的。 groupby 操作为我们提供更⾼层次的概括功能,通过它能很快速和有效地计算⼦数据
集的聚合数据。
GroupBy: Split, Apply, Combine
分组:拆分、应⽤、组合
Simple aggregations can give you a flavor of your dataset, but often we would prefer to aggregate conditionally on some
label or index: this is implemented in the so-called groupby operation. The name "group by" comes from a command in
the SQL database language, but it is perhaps more illuminative to think of it in the terms first coined by Hadley Wickham
of Rstats fame: split, apply, combine.
简单的聚合可以提供数据集的基础特征,但是通常我们更希望依据⼀些标签或索引条件进⾏聚合操作:这可以通过 groupby 操作实
现。"group by"的名称来⾃于SQL,但是将它想成是由Hadley Wickham⾸先创造的R数据统计术语会更合适:拆分、应⽤、组合。
Split, apply, combine
拆分、应⽤、组合
A canonical example of this split-apply-combine operation, where the "apply" is a summation aggregation, is illustrated in
this figure:
作为拆分-应⽤-组合操作的⼀个典型例⼦,下图展⽰了当进⾏求和的“应⽤”聚合操作时的情况:
[Figure: split-apply-combine with a sum aggregation (figure source in Appendix / 附录:生成图像的源代码)]
This makes clear what the groupby accomplishes:
The split step involves breaking up and grouping a DataFrame depending on the value of the specified key.
The apply step involves computing some function, usually an aggregate, transformation, or filtering, within the
individual groups.
The combine step merges the results of these operations into an output array.
上图很清晰地展⽰了 groupby 完成的⼯作:
拆分split步骤表⽰按照指定键上的值对 DataFrame 进⾏拆分和分组的功能。
应⽤apply步骤表⽰在每个独⽴的分组上调⽤某些函数进⾏计算,通常是聚合、转换或过滤。
组合combine步骤将上述计算的结果重新合并在⼀起输出。
While this could certainly be done manually using some combination of the masking, aggregation, and merging
commands covered earlier, an important realization is that the intermediate splits do not need to be explicitly instantiated.
Rather, the GroupBy can (often) do this in a single pass over the data, updating the sum, mean, count, min, or other
aggregate for each group along the way. The power of the GroupBy is that it abstracts away these steps: the user need
not think about how the computation is done under the hood, but rather thinks about the operation as a whole.
虽然这可以通过组合前面介绍过的掩码、聚合和合并指令来手动实现,但需要认识到的重要一点是,拆分的中间结果并不需要显式地创建出来。
GroupBy (通常)可以只遍历一次数据,在遍历过程中为每个分组更新总和、平均值、计数、最小值或其他聚合结果。GroupBy 的强大在于它
将这些步骤抽象了出来:用户不需要思考这些计算在底层是如何进行的,只需要把这些操作看作一个整体。
As a concrete example, let's take a look at using Pandas for the computation shown in this diagram. We'll start by
creating the input DataFrame :
作为⼀个具体的例⼦,我们来看⼀下使⽤Pandas来实现上⾯的这些计算,⾸先创建⼀个输⼊ DataFrame :
In [11]: df = pd.DataFrame({'key': ['A', 'B', 'C', 'A', 'B', 'C'],
'data': range(6)}, columns=['key', 'data'])
df
Out[11]:
  key  data
0   A     0
1   B     1
2   C     2
3   A     3
4   B     4
5   C     5
The most basic split-apply-combine operation can be computed with the groupby() method of DataFrame s,
passing the name of the desired key column:
最基础的拆分-应⽤-组合操作可以使⽤ DataFrame 的 groupby() ⽅法来实现,⽅法中传递作为键来运算的列名:
In [12]: df.groupby('key')
Out[12]: <pandas.core.groupby.generic.DataFrameGroupBy object at 0x7fd196fe1d30>
Notice that what is returned is not a set of DataFrame s, but a DataFrameGroupBy object. This object is where the
magic is: you can think of it as a special view of the DataFrame , which is poised to dig into the groups but does no
actual computation until the aggregation is applied. This "lazy evaluation" approach means that common aggregates can
be implemented very efficiently in a way that is almost transparent to the user.
上⾯运⾏的结果不是⼀个 DataFrame ,⽽是⼀个 DataFrameGroupBy 对象。这个对象就是上述步骤魔术的所在:你可以认为它是
DataFrame 对象的⼀个特殊的视图,使⽤它可以很容易的研究分组的数据,但是除⾮聚合操作发⽣,否则它不会进⾏真实的运算。这
种“懒运算”的⽅式意味着通⽤的聚合可以实现得⾮常的⾼效,⽽对⽤⼾来说⼏乎是透明的。
To produce a result, we can apply an aggregate to this DataFrameGroupBy object, which will perform the appropriate
apply/combine steps to produce the desired result:
要产⽣结果,我们可以将⼀个聚合操作应⽤到该 DataFrameGroupBy 对象上,这样就会在分组上执⾏应⽤/组合的步骤,并产⽣需要的结
果:
In [13]: df.groupby('key').sum()
Out[13]:
     data
key
A       3
B       5
C       7
The sum() method is just one possibility here; you can apply virtually any common Pandas or NumPy aggregation
function, as well as virtually any valid DataFrame operation, as we will see in the following discussion.
sum() 方法仅是其中一个可能的操作;你可以在这里应用几乎所有的Pandas或NumPy的通用聚合函数,也可以应用几乎所有合法的
DataFrame 操作,我们在下面马上就会看到。
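As noted earlier, the same result could be assembled by hand with masking and aggregation. The sketch below spells out the three steps explicitly for the `df` of In [11] and compares the result with `groupby`. It is an illustration of the idea, not how pandas implements it internally:

```python
import pandas as pd

df = pd.DataFrame({'key': ['A', 'B', 'C', 'A', 'B', 'C'],
                   'data': range(6)}, columns=['key', 'data'])

# Split: one boolean mask per unique key; apply: sum the masked subset
manual = {k: df.loc[df['key'] == k, 'data'].sum() for k in df['key'].unique()}
# Combine: assemble the per-group results into a single Series
manual = pd.Series(manual, name='data').rename_axis('key')

# groupby performs all three steps in one call
grouped = df.groupby('key')['data'].sum()
print(grouped.to_dict())  # {'A': 3, 'B': 5, 'C': 7}
```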
The GroupBy object
GroupBy 对象
The GroupBy object is a very flexible abstraction. In many ways, you can simply treat it as if it's a collection of
DataFrame s, and it does the difficult things under the hood. Let's see some examples using the Planets data.
GroupBy 对象是一个很灵活的抽象。在很多情况下,你可以将它简单的看成 DataFrame 的集合,它在底层做了很多复杂的工作。我们
用行星数据集来看几个例子。
Perhaps the most important operations made available by a GroupBy are aggregate, filter, transform, and apply. We'll
discuss each of these more fully in "Aggregate, Filter, Transform, Apply", but before that let's introduce some of the other
functionality that can be used with the basic GroupBy operation.
也许对 GroupBy 对象最重要的操作是聚合、过滤、转换和应⽤。我们会在聚合、过滤、转换、应⽤中逐个介绍它们,在这之前⾸先介绍
⼀些其他⽤于 GroupBy 对象的基础操作。
Column indexing
列索引
The GroupBy object supports column indexing in the same way as the DataFrame , and returns a modified
GroupBy object. For example:
GroupBy 对象支持列索引,与 DataFrame 相同,返回的是修改后的 GroupBy 对象。例如:
In [14]: planets.groupby('method')
Out[14]: <pandas.core.groupby.generic.DataFrameGroupBy object at 0x7fd196fa0358>
In [15]: planets.groupby('method')['orbital_period']
Out[15]: <pandas.core.groupby.generic.SeriesGroupBy object at 0x7fd19700ddd8>
Here we've selected a particular Series group from the original DataFrame group by reference to its column name.
As with the GroupBy object, no computation is done until we call some aggregate on the object:
上例中我们通过列名,从原始的 DataFrame 分组中选出了一个特定的 Series 分组。与 GroupBy 对象一样,在调用聚合操作之前不会进行
任何计算:
In [16]: planets.groupby('method')['orbital_period'].median()
Out[16]: method
         Astrometry                         631.180000
         Eclipse Timing Variations         4343.500000
         Imaging                          27500.000000
         Microlensing                      3300.000000
         Orbital Brightness Modulation        0.342887
         Pulsar Timing                       66.541900
         Pulsation Timing Variations       1170.000000
         Radial Velocity                    360.200000
         Transit                              5.714932
         Transit Timing Variations           57.011000
         Name: orbital_period, dtype: float64
This gives an idea of the general scale of orbital periods (in days) that each method is sensitive to.
结果大致给出了每种测量方法所能探测到的公转周期(以天为单位)的量级。
Iteration over groups
在分组上进⾏迭代
The GroupBy object supports direct iteration over the groups, returning each group as a Series or DataFrame :
GroupBy 对象支持在分组上直接进行迭代,每次迭代返回分组的一个 Series 或 DataFrame 对象:
In [17]: for (method, group) in planets.groupby('method'):
    print("{0:30s} shape={1}".format(method, group.shape))
Astrometry                     shape=(2, 6)
Eclipse Timing Variations      shape=(9, 6)
Imaging                        shape=(38, 6)
Microlensing                   shape=(23, 6)
Orbital Brightness Modulation  shape=(3, 6)
Pulsar Timing                  shape=(5, 6)
Pulsation Timing Variations    shape=(1, 6)
Radial Velocity                shape=(553, 6)
Transit                        shape=(397, 6)
Transit Timing Variations      shape=(4, 6)
This can be useful for doing certain things manually, though it is often much faster to use the built-in apply
functionality, which we will discuss momentarily.
这种做法在某些需要⼿动实现的情况下很有⽤,虽然通常来说使⽤內建的 apply 函数会快很多,我们⻢上会介绍到 apply 函数。
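For simple per-group summaries, a dedicated GroupBy method usually replaces such a loop entirely. The sketch below counts the rows per group both ways on a small hand-made frame. `size()` is a standard GroupBy method. Choosing it here is our own illustration, since the text itself points to `apply`:

```python
import pandas as pd

df = pd.DataFrame({'key': ['A', 'B', 'C', 'A', 'B', 'C'],
                   'data': range(6)})

# Manual loop over (name, group) pairs
sizes_loop = {name: len(group) for name, group in df.groupby('key')}

# Built-in equivalent: one call, no Python-level loop
sizes_builtin = df.groupby('key').size().to_dict()

print(sizes_loop)  # {'A': 2, 'B': 2, 'C': 2}
```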
Dispatch methods
扩展⽅法
Through some Python class magic, any method not explicitly implemented by the GroupBy object will be passed
through and called on the groups, whether they are DataFrame or Series objects. For example, you can use the
describe() method of DataFrame s to perform a set of aggregations that describe each group in the data:
通过一些Python面向对象的魔术技巧,任何没有在 GroupBy 对象上显式定义的方法,都会被传递到各个分组(无论是 DataFrame 还是
Series 对象)上调用。例如,你可以在数据分组上调用 DataFrame 的 describe() 方法,对每个分组进行一组描述性的聚合运算:
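Translator's note: the author's original code appended unstack() to this line. In current pandas, SeriesGroupBy.describe() already returns one row per group with the statistics as columns, so the call is omitted here.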
In [18]: planets.groupby('method')['year'].describe()
Out[18]:
                               count         mean       std     min      25%     50%      75%     max
method
Astrometry                       2.0  2011.500000  2.121320  2010.0  2010.75  2011.5  2012.25  2013.0
Eclipse Timing Variations        9.0  2010.000000  1.414214  2008.0  2009.00  2010.0  2011.00  2012.0
Imaging                         38.0  2009.131579  2.781901  2004.0  2008.00  2009.0  2011.00  2013.0
Microlensing                    23.0  2009.782609  2.859697  2004.0  2008.00  2010.0  2012.00  2013.0
Orbital Brightness Modulation    3.0  2011.666667  1.154701  2011.0  2011.00  2011.0  2012.00  2013.0
Pulsar Timing                    5.0  1998.400000  8.384510  1992.0  1992.00  1994.0  2003.00  2011.0
Pulsation Timing Variations      1.0  2007.000000       NaN  2007.0  2007.00  2007.0  2007.00  2007.0
Radial Velocity                553.0  2007.518987  4.249052  1989.0  2005.00  2009.0  2011.00  2014.0
Transit                        397.0  2011.236776  2.077867  2002.0  2010.00  2012.0  2013.00  2014.0
Transit Timing Variations        4.0  2012.500000  1.290994  2011.0  2011.75  2012.5  2013.25  2014.0
Looking at this table helps us to better understand the data: for example, the vast majority of planets have been
discovered by the Radial Velocity and Transit methods, though the latter only became common (due to new, more
accurate telescopes) in the last decade. The newest methods seem to be Transit Timing Variation and Orbital Brightness
Modulation, which were not used to discover a new planet until 2011.
查看上表,能帮助我们更好的理解数据:例如,发现⾏星最多的⽅法是径向速度和凌⽇法,虽然后者是近⼗年才变得普遍(因为新的更精
准的望远镜的作⽤)。最新的⽅法应该是凌⽇时间变分法和轨道亮度调制法,它们直⾄2011年才开始发现新的⾏星。
This is just one example of the utility of dispatch methods. Notice that they are applied to each individual group, and the
results are then combined within GroupBy and returned. Again, any valid DataFrame / Series method can be used
on the corresponding GroupBy object, which allows for some very flexible and powerful operations!
这只是⼀个使⽤扩展⽅法的例⼦。你需要知道的是这些⽅法会被应⽤到每⼀个独⽴的分组上,然后计算得到的结果会在 GroupBy 对象中
合并并返回。再次提⽰,任何正确的 DataFrame 或 Series ⽅法都能在相应的 GroupBy 对象上使⽤,这种扩展⽅法的⽅式提供了⾮常
灵活及强⼤的操作。
Aggregate, filter, transform, apply
聚合、过滤、转换、应⽤
The preceding discussion focused on aggregation for the combine operation, but there are more options available. In
particular, GroupBy objects have aggregate() , filter() , transform() , and apply() methods that
efficiently implement a variety of useful operations before combining the grouped data.
前⾯的讨论聚焦在组合操作相应的聚合函数上,但实际上还有更多的可能选项。特别是 GroupBy 对象有 aggregate() 、
filter() 、 transform() 和 apply() 方法,它们能在组合分组数据之前有效地实现大量有用的操作。
For the purpose of the following subsections, we'll use this DataFrame :
对于下⾯的部分内容,我们将使⽤下述的 DataFrame :
In [19]: rng = np.random.RandomState(0)
df = pd.DataFrame({'key': ['A', 'B', 'C', 'A', 'B', 'C'],
'data1': range(6),
'data2': rng.randint(0, 10, 6)},
columns = ['key', 'data1', 'data2'])
df
Out[19]:
  key  data1  data2
0   A      0      5
1   B      1      0
2   C      2      3
3   A      3      3
4   B      4      7
5   C      5      9
Aggregation
聚合
We're now familiar with GroupBy aggregations with sum() , median() , and the like, but the aggregate()
method allows for even more flexibility. It can take a string, a function, or a list thereof, and compute all the aggregates at
once. Here is a quick example combining all these:
我们已经熟悉了 GroupBy 使⽤ sum() 、 median() 等⽅法进⾏聚合的做法,但是 aggregate() ⽅法能提供更多的灵活性。它能接受
字符串、函数或者⼀个列表,然后⼀次性计算出所有的聚合结果。下⾯是⼀个简单的例⼦:
In [20]: df.groupby('key').aggregate(['min', np.median, max])
Out[20]:
    data1            data2
      min median max   min median max
key
A       0    1.5   3     3    4.0   5
B       1    2.5   4     0    3.5   7
C       2    3.5   5     3    6.0   9
Another useful pattern is to pass a dictionary mapping column names to operations to be applied on that column:
还可以将⼀个字典,⾥⾯是列名与操作的对应关系,传递给 aggregate() 来进⾏⼀次性的聚合运算:
In [21]: df.groupby('key').aggregate({'data1': 'min',
'data2': 'max'})
Out[21]:
     data1  data2
key
A        0      5
B        1      7
C        2      9
Filtering
过滤
A filtering operation allows you to drop data based on the group properties. For example, we might want to keep all
groups in which the standard deviation is larger than some critical value:
过滤操作能在分组数据上移除⼀些你不需要的数据。例如,我们可能希望保留标准差⼤于某个阈值的所有的分组:
Translator's note: you can think of filter() as similar to the HAVING clause in SQL.
In [22]: def filter_func(x):
    # keep a group only if the standard deviation of its data2 values exceeds 4
    return x['data2'].std() > 4

display('df', "df.groupby('key').std()", "df.groupby('key').filter(filter_func)")
Out[22]:
df
  key  data1  data2
0   A      0      5
1   B      1      0
2   C      2      3
3   A      3      3
4   B      4      7
5   C      5      9

df.groupby('key').std()
       data1     data2
key
A    2.12132  1.414214
B    2.12132  4.949747
C    2.12132  4.242641

df.groupby('key').filter(filter_func)
  key  data1  data2
1   B      1      0
2   C      2      3
4   B      4      7
5   C      5      9
The filter function should return a Boolean value specifying whether the group passes the filtering. Here because group A
does not have a standard deviation greater than 4, it is dropped from the result.
⽤来进⾏过滤的函数必须返回⼀个布尔值,表⽰分组是否能够通过过滤条件。上例中A分组的标准差不是⼤于4,因此整个分组在结果中被
移除了。
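The same selection can also be expressed with `transform()`, which broadcasts each group's statistic back onto its rows so that an ordinary boolean mask can be used. Here is a sketch reproducing In [22] both ways. The `transform` route is our alternative, not something the text introduces:

```python
import numpy as np
import pandas as pd

rng = np.random.RandomState(0)
df = pd.DataFrame({'key': ['A', 'B', 'C', 'A', 'B', 'C'],
                   'data1': range(6),
                   'data2': rng.randint(0, 10, 6)},
                  columns=['key', 'data1', 'data2'])

# filter(): keep whole groups whose data2 standard deviation exceeds 4
kept = df.groupby('key').filter(lambda x: x['data2'].std() > 4)

# transform(): one std value per row, then a plain boolean mask
mask = df.groupby('key')['data2'].transform('std') > 4
masked = df[mask]

print(kept.index.tolist())  # [1, 2, 4, 5]
```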
Transformation
转换
While aggregation must return a reduced version of the data, transformation can return some transformed version of the
full data to recombine. For such a transformation, the output is the same shape as the input. A common example is to
center the data by subtracting the group-wise mean:
聚合返回的是分组简化后的数据集,⽽转换可以返回完整数据转换后并重新合并的数据集。因此转换操作的结果和输⼊数据集具有相同的
形状。⼀个通⽤例⼦是将整个数据集通过减去每个分组的平均值进⾏中⼼化:
In [23]: df.groupby('key').transform(lambda x: x - x.mean())
Out[23]:
   data1  data2
0   -1.5    1.0
1   -1.5   -3.5
2   -1.5   -3.0
3    1.5   -1.0
4    1.5    3.5
5    1.5    3.0
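`transform()` also accepts the name of an aggregation such as `'mean'`. In that case each group's aggregate is broadcast back onto that group's rows, so centering can be written as a plain subtraction. Here is a sketch checking both forms on the `data2` column of In [19]:

```python
import numpy as np
import pandas as pd

rng = np.random.RandomState(0)
df = pd.DataFrame({'key': ['A', 'B', 'C', 'A', 'B', 'C'],
                   'data1': range(6),
                   'data2': rng.randint(0, 10, 6)},
                  columns=['key', 'data1', 'data2'])

# Function form: subtract each group's mean inside the group
centered = df.groupby('key')['data2'].transform(lambda x: x - x.mean())

# String form: broadcast group means back to the rows, then subtract
centered2 = df['data2'] - df.groupby('key')['data2'].transform('mean')

print(centered.tolist())  # [1.0, -3.5, -3.0, -1.0, 3.5, 3.0]
```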
The apply() method
apply() 方法
The apply() method lets you apply an arbitrary function to the group results. The function should take a
DataFrame , and return either a Pandas object (e.g., DataFrame , Series ) or a scalar; the combine operation will
be tailored to the type of output returned.
apply() 方法能让你将任意函数应用到各个分组上。该函数必须接受一个 DataFrame 参数,返回一个Pandas对象(如 DataFrame 、
Series )或者一个标量;组合操作会根据返回的类型进行适配。
For example, here is an apply() that normalizes the first column by the sum of the second:
例如,下面采用 apply() 使用 data2 的分组总和来对 data1 的值进行归一化:
In [24]: def norm_by_data2(x):
    # x is a DataFrame of group values
    x['data1'] /= x['data2'].sum()
    return x

display('df', "df.groupby('key').apply(norm_by_data2)")
Out[24]:
df
  key  data1  data2
0   A      0      5
1   B      1      0
2   C      2      3
3   A      3      3
4   B      4      7
5   C      5      9

df.groupby('key').apply(norm_by_data2)
  key     data1  data2
0   A  0.000000      5
1   B  0.142857      0
2   C  0.166667      3
3   A  0.375000      3
4   B  0.571429      7
5   C  0.416667      9
apply() within a GroupBy is quite flexible: the only criterion is that the function takes a DataFrame and returns a
Pandas object or scalar; what you do in the middle is up to you!
GroupBy 对象的 apply() 方法是非常灵活的:唯一的限制就是应用的函数要接受一个 DataFrame 参数并且返回一个Pandas对象或者
标量;函数体内做什么工作完全是自定义的。
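To see how the combine step adapts to the return type, here is a sketch that uses two functions on the `df` of In [19]. One returns a scalar (combined into a Series) and one returns a Series (combined into a DataFrame). Both functions and the output names `span` and `total` are illustrative choices. Selecting `[['data1', 'data2']]` first keeps the grouping column out of the groups that are passed to the function:

```python
import numpy as np
import pandas as pd

rng = np.random.RandomState(0)
df = pd.DataFrame({'key': ['A', 'B', 'C', 'A', 'B', 'C'],
                   'data1': range(6),
                   'data2': rng.randint(0, 10, 6)},
                  columns=['key', 'data1', 'data2'])
grouped = df.groupby('key')[['data1', 'data2']]

# Scalar per group -> combined into a Series indexed by key
ratio = grouped.apply(lambda g: g['data1'].sum() / g['data2'].sum())

# Series per group -> combined into a DataFrame with one column per Series label
summary = grouped.apply(
    lambda g: pd.Series({'span': g['data1'].max() - g['data1'].min(),
                         'total': g['data2'].sum()}))

print(ratio.round(3).to_dict())  # {'A': 0.375, 'B': 0.714, 'C': 0.583}
```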
Specifying the split key
指定拆分键
In the simple examples presented before, we split the DataFrame on a single column name. This is just one of many
options by which the groups can be defined, and we'll go through some other options for group specification here.
在前⾯的简单例⼦中,我们使⽤⼀个列名对 DataFrame 进⾏拆分。这只是分组的众多⽅式的其中之⼀,我们下⾯继续探讨其他的选项。
A list, array, series, or index providing the grouping keys
使⽤列表、数组、序列或索引指定分组键
The key can be any series or list with a length matching that of the DataFrame . For example:
分组使用的键可以是任何的序列或列表,只要长度和 DataFrame 的长度互相匹配即可。例如:
In [25]: L = [0, 1, 0, 1, 2, 0]
display('df', 'df.groupby(L).sum()')
Out[25]:
df
  key  data1  data2
0   A      0      5
1   B      1      0
2   C      2      3
3   A      3      3
4   B      4      7
5   C      5      9

df.groupby(L).sum()
   data1  data2
0      7     17
1      4      3
2      4      7
Of course, this means there's another, more verbose way of accomplishing the df.groupby('key') from before:
当然,这也意味着前面的 df.groupby('key') 还有另外一种更冗长的写法:
In [26]: display('df', "df.groupby(df['key']).sum()")
Out[26]:
df
  key  data1  data2
0   A      0      5
1   B      1      0
2   C      2      3
3   A      3      3
4   B      4      7
5   C      5      9

df.groupby(df['key']).sum()
     data1  data2
key
A        3      8
B        5      7
C        7     12
A dictionary or series mapping index to group
使⽤字典或映射索引的序列来分组
Another method is to provide a dictionary that maps index values to the group keys:
还有⼀种⽅法是提供⼀个字典,将索引值映射成分组键:
In [27]: df2 = df.set_index('key')
mapping = {'A': 'vowel', 'B': 'consonant', 'C': 'consonant'}
display('df2', 'df2.groupby(mapping).sum()')
Out[27]:
df2
     data1  data2
key
A        0      5
B        1      0
C        2      3
A        3      3
B        4      7
C        5      9

df2.groupby(mapping).sum()
           data1  data2
consonant     12     19
vowel          3      8
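To make the mapping step explicit, `Index.map` (which accepts a function, a dict or a Series) can turn the index values into per-row group labels. Grouping by those labels gives the same sums as passing the dict directly. This is a sketch of what the dict form amounts to, not a description of pandas internals:

```python
import numpy as np
import pandas as pd

rng = np.random.RandomState(0)
df = pd.DataFrame({'key': ['A', 'B', 'C', 'A', 'B', 'C'],
                   'data1': range(6),
                   'data2': rng.randint(0, 10, 6)},
                  columns=['key', 'data1', 'data2'])
df2 = df.set_index('key')

mapping = {'A': 'vowel', 'B': 'consonant', 'C': 'consonant'}

# Turn each index value into its group label (a Series works as the mapper too)
labels = df2.index.map(pd.Series(mapping))
by_labels = df2.groupby(labels).sum()

# Passing the dict directly, as in In [27]
by_dict = df2.groupby(mapping).sum()
print(by_labels.to_dict())
# {'data1': {'consonant': 12, 'vowel': 3}, 'data2': {'consonant': 19, 'vowel': 8}}
```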
Any Python function
任何Python函数
Similar to mapping, you can pass any Python function that will input the index value and output the group:
类似映射,你可以传递任何Python函数将输⼊的索引值变成输出的分组键:
In [28]: display('df2', 'df2.groupby(str.lower).mean()')
Out[28]:
df2
     data1  data2
key
A        0      5
B        1      0
C        2      3
A        3      3
B        4      7
C        5      9

df2.groupby(str.lower).mean()
   data1  data2
a    1.5    4.0
b    2.5    3.5
c    3.5    6.0
A list of valid keys
有效键的列表
Further, any of the preceding key choices can be combined to group on a multi-index:
此外,前面介绍的任何分组键都可以组合起来,在多级索引上进行分组:
In [29]: df2.groupby([str.lower, mapping]).mean()
Out[29]:
             data1  data2
a vowel        1.5    4.0
b consonant    2.5    3.5
c consonant    3.5    6.0
Grouping example
分组例⼦
As an example of this, in a couple lines of Python code we can put all these together and count discovered planets by
method and by decade:
作为分组的例⼦,我们将前⾯介绍的内容⽤⼏⾏Python代码写出来⽤于计算通过不同⽅法在不同年代发现的⾏星的个数:
In [30]: decade = 10 * (planets['year'] // 10)
decade = decade.astype(str) + 's'
decade.name = 'decade'
planets.groupby(['method', decade])['number'].sum().unstack().fillna(0)
Out[30]:
decade                         1980s  1990s  2000s  2010s
method
Astrometry                       0.0    0.0    0.0    2.0
Eclipse Timing Variations        0.0    0.0    5.0   10.0
Imaging                          0.0    0.0   29.0   21.0
Microlensing                     0.0    0.0   12.0   15.0
Orbital Brightness Modulation    0.0    0.0    0.0    5.0
Pulsar Timing                    0.0    9.0    1.0    1.0
Pulsation Timing Variations      0.0    0.0    1.0    0.0
Radial Velocity                  1.0   52.0  475.0  424.0
Transit                          0.0    0.0   64.0  712.0
Transit Timing Variations        0.0    0.0    0.0    9.0
This shows the power of combining many of the operations we've discussed up to this point when looking at realistic
datasets. We immediately gain a coarse understanding of when and how planets have been discovered over the past
several decades!
这个例⼦展⽰了我们结合前⾯介绍过的多种操作之后,我们能在真实的数据集上完成多强⼤的操作。我们⽴即获得了过去⼏⼗年间我们是
如何发现⾏星的⼤概统计。
Here I would suggest digging into these few lines of code, and evaluating the individual steps to make sure you
understand exactly what they are doing to the result. It's certainly a somewhat complicated example, but understanding
these pieces will give you the means to similarly explore your own data.
作者建议你深⼊研究上⾯的⼏⾏代码,逐步的执⾏它们,直到你完全理解了这些代码是如何最终产⽣结果的。当然上⾯是⼀个稍微复杂的
例⼦,但是理解这个例⼦会让你在研究⾃⼰的数据集时知道如何进⾏操作。
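Following that suggestion, here is the same pipeline run step by step on a tiny hypothetical table (`toy`). Its methods, years and counts are made up for illustration and are not taken from the planets data, so every intermediate result can be checked by hand:

```python
import pandas as pd

# Hypothetical discovery records (not the planets dataset)
toy = pd.DataFrame({'method': ['RV', 'RV', 'Transit', 'Transit', 'RV'],
                    'number': [1, 1, 2, 1, 1],
                    'year': [1989, 1995, 2004, 2010, 2014]})

# Step 1: floor-divide by 10 and scale back -> first year of each decade
decade = 10 * (toy['year'] // 10)        # 1980, 1990, 2000, 2010, 2010
# Step 2: turn the numbers into labels such as '1980s'
decade = decade.astype(str) + 's'
# Step 3: name the Series so it becomes the column-level name after unstack()
decade.name = 'decade'

# Group on a column name plus an aligned Series, sum, pivot decades into columns
table = toy.groupby(['method', decade])['number'].sum().unstack().fillna(0)
print(table)
```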
Pivot Tables
数据透视表
We have seen how the GroupBy abstraction lets us explore relationships within a dataset. A pivot table is a similar
operation that is commonly seen in spreadsheets and other programs that operate on tabular data. The pivot table takes
simple column-wise data as input, and groups the entries into a two-dimensional table that provides a multidimensional
summarization of the data. The difference between pivot tables and GroupBy can sometimes cause confusion; it helps
me to think of pivot tables as essentially a multidimensional version of GroupBy aggregation. That is, you split-apply-combine, but both the split and the combine happen not across a one-dimensional index, but across a two-dimensional grid.
上⼀节我们学习了使⽤ GroupBy 来处理数据集之间的关系。数据透视表也是⼀个类似的操作,我们经常会在电⼦表格或其他处理表格数
据的程序中看到它。数据透视表将列状的数据作为输⼊,然后将它们组合到⼀个⼆维的表格中,通过这种组合结果提供数据在多个维度上
的统计数据。数据透视表和 GroupBy 之间的区别经常会造成⼀些混乱;如果我们将数据透视表想象成⼀个多维版本的 GroupBy 聚合,
会容易很多。也就是说,依然通过拆分-应⽤-组合的步骤,不过不是在⼀维的索引上进⾏,⽽是在⼆维的表格中进⾏。
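To make the "multidimensional GroupBy" idea concrete, here is a minimal sketch on a small invented DataFrame (the column names only mimic the Titanic data used below). It computes the same grid twice: once with `groupby` plus `unstack`, once with `pivot_table`, whose default aggregation is the mean.

```python
import pandas as pd

# Small invented table: one row per person, two categorical keys, one value.
df = pd.DataFrame({
    'sex':      ['female', 'male', 'female', 'male', 'male', 'female'],
    'class':    ['First', 'First', 'Third', 'Third', 'Third', 'First'],
    'survived': [1, 0, 1, 0, 1, 1],
})

# GroupBy version: split on both keys, apply the mean, combine,
# then unstack the inner key so it spans the columns.
via_groupby = df.groupby(['sex', 'class'])['survived'].mean().unstack()

# pivot_table version: the same split-apply-combine, laid out as a grid directly.
via_pivot = df.pivot_table('survived', index='sex', columns='class')

print(via_pivot)
```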
Motivating Pivot Tables
进⼊数据透视表
For the examples in this section, we'll use the database of passengers on the Titanic, available through the Seaborn
library (see Visualization With Seaborn):
本⼩节的例⼦,我们将采⽤泰坦尼克的乘客数据,同样来⾃Seaborn库(参⻅使⽤Seaborn进⾏可视化):
In [1]: import numpy as np
import pandas as pd
import seaborn as sns
titanic = sns.load_dataset('titanic')
In [2]: titanic.head()
Out[2]:
   survived  pclass     sex   age  sibsp  parch     fare embarked  class    who  adult_male deck  embark_town alive  alone
0         0       3    male  22.0      1      0   7.2500        S  Third    man        True  NaN  Southampton    no  False
1         1       1  female  38.0      1      0  71.2833        C  First  woman       False    C    Cherbourg   yes  False
2         1       3  female  26.0      0      0   7.9250        S  Third  woman       False  NaN  Southampton   yes   True
3         1       1  female  35.0      1      0  53.1000        S  First  woman       False    C  Southampton   yes  False
4         0       3    male  35.0      0      0   8.0500        S  Third    man        True  NaN  Southampton    no   True
This contains a wealth of information on each passenger of that ill-fated voyage, including gender, age, class, fare paid,
and much more.
这个数据集包含了每⼀个乘客在他们的那次致命之旅中的很多信息,包括性别、年龄、舱位、票价等等。
Pivot Tables by Hand
⼿动⽣成数据透视表
To start learning more about this data, we might begin by grouping according to gender, survival status, or some
combination thereof. If you have read the previous section, you might be tempted to apply a GroupBy operation–for
example, let's look at survival rate by gender:
在深⼊分析数据之前,我们⾸先根据性别和存活状态的相关性进⾏分组。根据上⼀节的内容,你可能会⾃然⽽然地使⽤ GroupBy 操作,
例如,让我们来获得不同性别的存活率:
In [3]: titanic.groupby('sex')[['survived']].mean()
Out[3]:
        survived
sex
female  0.742038
male    0.188908
This immediately gives us some insight: overall, three of every four females on board survived, while only one in five
males survived!
这个结果⽴刻能给我们⼀些数据的内在意义:普遍来说,四分之三的⼥性都存活了下来,⽽只有五分之⼀的男性存活了下来!
This is useful, but we might like to go one step deeper and look at survival by both sex and, say, class. Using the
vocabulary of GroupBy , we might proceed using something like this: we group by class and gender, select survival,
apply a mean aggregate, combine the resulting groups, and then unstack the hierarchical index to reveal the hidden
multidimensionality. In code:
这很有⽤,但是我们可能希望进⼀步了解根据性别和舱位来统计存活率。如果我们⽤ GroupBy 的⽅法来描述这个过程的话,那么很可能
是这样的:我们使⽤舱位和性别来分组,选择存活状态,应⽤平均值聚合操作,将结果的分组组合起来,然后展开成层次化的索引来展⽰
隐藏的⾼维度。代码如下:
In [4]: titanic.groupby(['sex', 'class'])['survived'].aggregate('mean').unstack()
Out[4]:
class      First    Second     Third
sex
female  0.968085  0.921053  0.500000
male    0.368852  0.157407  0.135447
This gives us a better idea of how both gender and class affected survival, but the code is starting to look a bit garbled.
While each step of this pipeline makes sense in light of the tools we've previously discussed, the long string of code is not
particularly easy to read or use. This two-dimensional GroupBy is common enough that Pandas includes a
convenience routine, pivot_table , which succinctly handles this type of multi-dimensional aggregation.
结果给了我们⼀个更好的关于性别和舱位是如何影响存活率的视⻆,但是代码已经开始显得有点混乱和难以阅读了。当我们采⽤之前的知
识来实现这个操作流的每⼀步的时候,代码会变得越来越⻓,将会越来越难以使⽤和阅读。这种⼆维的 GroupBy 对于在Pandas中进⾏普
通分组统计时是⾜够的,⽽透视表 pivot_table ,能简洁的处理这种多维度的聚合操作。
Pivot Table Syntax
数据透视表语法
Here is the equivalent to the preceding operation using the pivot_table method of DataFrame s:
下⾯是我们使⽤ DataFrame 的 pivot_table 来实现这个操作的版本:
In [5]: titanic.pivot_table('survived', index='sex', columns='class')
Out[5]:
class      First    Second     Third
sex
female  0.968085  0.921053  0.500000
male    0.368852  0.157407  0.135447
This is eminently more readable than the groupby approach, and produces the same result. As you might expect of an
early 20th-century transatlantic cruise, the survival gradient favors both women and higher classes. First-class women
survived with near certainty (hi, Rose!), while only about one in seven third-class men survived (sorry, Jack!).
上面的语法明显比 groupby 版本要易读多了,两者的结果是一致的。结果告诉我们如果要搭乘20世纪初的跨大西洋游轮的话,生存几率
更加青睐于女性和高级舱位。头等舱女性几乎全部存活(Rose你好),而三等舱的男性只有大约七分之一的几率存活(Jack抱歉)。
Translator's note: Jack and Rose are the names of the male and female leads of the 1997 film Titanic, directed by James Cameron.
Multi-level pivot tables
多层透视表
Just as in the GroupBy , the grouping in pivot tables can be specified with multiple levels, and via a number of options.
For example, we might be interested in looking at age as a third dimension. We'll bin the age using the pd.cut
function:
就像 GroupBy 那样,数据透视表的分组也可以指定多层次,还可以指定其他多个参数。例如,我们可能想要将年龄作为第三个维度。我
们可以使⽤ pd.cut 将年龄进⾏分桶:
In [6]: age = pd.cut(titanic['age'], [0, 18, 80])
titanic.pivot_table('survived', ['sex', age], 'class')
Out[6]:
class               First    Second     Third
sex    age
female (0, 18]   0.909091  1.000000  0.511628
       (18, 80]  0.972973  0.900000  0.423729
male   (0, 18]   0.800000  0.600000  0.215686
       (18, 80]  0.375000  0.071429  0.133663
We can apply the same strategy when working with the columns as well; let's add info on the fare paid using pd.qcut
to automatically compute quantiles:
我们也可以将相同的⽅法应⽤到列上;下⾯我们在列上加上船票费⽤分组,使⽤ pd.qcut 将费⽤按⽐例⾃动分桶:
In [7]: fare = pd.qcut(titanic['fare'], 2)
titanic.pivot_table('survived', ['sex', age], [fare, 'class'])
Out[7]:
fare            (-0.001, 14.454]                     (14.454, 512.329]
class                      First    Second     Third              First    Second     Third
sex    age
female (0, 18]               NaN  1.000000  0.714286           0.909091  1.000000  0.318182
       (18, 80]              NaN  0.880000  0.444444           0.972973  0.914286  0.391304
male   (0, 18]               NaN  0.000000  0.260870           0.800000  0.818182  0.178571
       (18, 80]              0.0  0.098039  0.125000           0.391304  0.030303  0.192308
The result is a four-dimensional aggregation with hierarchical indices (see Hierarchical Indexing), shown in a grid
demonstrating the relationship between the values.
结果是⼀个四维的统计表,⾏和列都具有层次化的索引(参⻅层次化索引),以表格的形式展⽰了对应四个不同维度的聚合数据。
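The two binning helpers used above behave differently, which this small sketch on invented ages and fares illustrates: `pd.cut` takes explicit bin edges (right-closed by default), while `pd.qcut` takes a number of quantiles and chooses the edges so each bin holds roughly the same number of values.

```python
import pandas as pd

ages = pd.Series([2, 15, 18, 30, 45, 79])
fares = pd.Series([5.0, 7.0, 8.0, 30.0, 60.0, 100.0])

# pd.cut: you supply the edges; intervals are right-closed,
# so 18 falls in (0, 18] and 30 falls in (18, 80].
age_bins = pd.cut(ages, [0, 18, 80])

# pd.qcut: you supply the number of quantiles; with 2 this is a median split,
# and pandas picks the edge (here the median, 19.0) from the data itself.
fare_bins = pd.qcut(fares, 2)

print(age_bins.value_counts().sort_index())
print(fare_bins.value_counts().sort_index())
```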
Additional pivot table options
其他透视表参数
The full call signature of the pivot_table method of DataFrame s is as follows:
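DataFrame 的 pivot_table 方法的完整签名如下:
# Signature of pivot_table in Pandas 0.24.2 (later versions add parameters such as observed and sort)
pd.pivot_table(
    data,                # a DataFrame; when called as a method, this is self
    values=None,         # column(s) to aggregate
    index=None,          # keys to group by on the rows
    columns=None,        # keys to group by on the columns
    aggfunc='mean',      # aggregation function(s); the mean by default
    fill_value=None,     # value used to replace missing values in the result
    margins=False,       # add an 'All' row and column of totals
    dropna=True,         # do not include columns whose entries are all NaN
    margins_name='All',  # label of the totals row and column
)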
We've already seen examples of the first three arguments; here we'll take a quick look at the remaining ones. Two of the
options, fill_value and dropna , have to do with missing data and are fairly straightforward; we will not show
examples of them here.
前三个参数(除data外)前⾯的例⼦中已经介绍过了;这⾥我们简单的介绍余下的⼏个参数。其中的 fill_value 和 dropna 与数据集
的缺失值相关,前⾯我们也都看到过;这⾥我们就不举例了。
The aggfunc keyword controls what type of aggregation is applied, which is a mean by default. As in the GroupBy, the
aggregation specification can be a string representing one of several common choices (e.g., 'sum', 'mean',
'count', 'min', 'max', etc.) or a function that implements an aggregation (e.g., np.sum, min, sum,
etc.). Additionally, it can be specified as a dictionary mapping a column to any of the above desired options:
aggfunc 参数指定数据透视表使用的聚合函数,默认是平均值 'mean' 。就像 GroupBy 中一样,聚合函数可以通过函数名称的字符串
来指定(例如 'sum' 、 'mean' 、 'count' 、 'min' 、 'max' 等)。除此之外,也可以通过一个字典将列与聚合函数对应起来作
为 aggfunc 的参数。
In [8]: titanic.pivot_table(index='sex', columns='class',
                        aggfunc={'survived': 'sum', 'fare': 'mean'})
Out[8]:
              fare                             survived
class        First     Second      Third  First Second Third
sex
female  106.125798  21.970121  16.118810     91     70    72
male     67.226127  19.741782  12.661633     45     17    47
Notice also here that we've omitted the values keyword; when specifying a mapping for aggfunc , this is determined
automatically.
上⾯的例⼦中, values 参数也被忽略了;当我们将列和聚合函数映射的字典传递到 aggfunc 参数时,进⾏聚合的列显然是不需要指定
的。
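Here is a minimal sketch of the dictionary form of `aggfunc` on a small invented DataFrame (values chosen so the sums and means are easy to check by hand). As in the Titanic example, `values` is omitted and the aggregated columns come from the dictionary keys, which become the outer level of the column MultiIndex.

```python
import pandas as pd

df = pd.DataFrame({
    'sex':      ['female', 'female', 'female', 'male', 'male', 'male'],
    'class':    ['First', 'First', 'Third', 'First', 'Third', 'Third'],
    'survived': [1, 1, 0, 0, 1, 0],
    'fare':     [100.0, 80.0, 8.0, 60.0, 7.0, 9.0],
})

# Sum one column, average another; no 'values' argument needed.
table = df.pivot_table(index='sex', columns='class',
                       aggfunc={'survived': 'sum', 'fare': 'mean'})
print(table)
```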
At times it's useful to compute totals along each grouping. This can be done via the margins keyword:
很多时候,对每个组进⾏总计(或者⼩计)是很有⽤的。这可以通过指定 margins 参数来计算:
In [9]: titanic.pivot_table('survived', index='sex', columns='class', margins=True)
Out[9]:
class      First    Second     Third       All
sex
female  0.968085  0.921053  0.500000  0.742038
male    0.368852  0.157407  0.135447  0.188908
All     0.629630  0.472826  0.242363  0.383838
Here this automatically gives us information about the class-agnostic survival rate by gender, the gender-agnostic
survival rate by class, and the overall survival rate of 38%. The margin label can be specified with the margins_name
keyword, which defaults to "All" .
结果最后⼀⾏展⽰了所有性别不同舱位的存活率,最后⼀列展⽰了所有舱位不同性别的存活率,⽽右下⻆的数字代表总体存活率,约为
38%。总计(或小计)的标签可以通过 margins_name 参数来指定,默认为 "All" 。
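A small sketch with an invented DataFrame checks one detail of `margins`: the totals are computed from the underlying rows, not by averaging the cell means. That is why the bottom-right value equals the plain overall mean. It also shows `margins_name` renaming the default "All" label.

```python
import pandas as pd

df = pd.DataFrame({
    'sex':      ['female', 'female', 'male', 'male', 'male'],
    'class':    ['First', 'Third', 'First', 'Third', 'Third'],
    'survived': [1, 0, 1, 0, 0],
})

table = df.pivot_table('survived', index='sex', columns='class',
                       margins=True, margins_name='Total')
print(table)

# The male row total is mean([1, 0, 0]) = 1/3, not the mean of the
# two male cells (1.0 and 0.0), which would be 0.5.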
Example: Birthrate Data
例⼦:出⽣率数据
As a more interesting example, let's take a look at the freely available data on births in the United States, provided by the
Centers for Disease Control (CDC). This data can be found at https://raw.githubusercontent.com/jakevdp/data-CDCbirths/master/births.csv (this dataset has been analyzed rather extensively by Andrew Gelman and his group; see,
for example, this blog post):
下⾯来看⼀个更有趣的例⼦,使⽤由疾控中⼼提供的可⾃由获取使⽤的美国的⼈⼝出⽣数据。这个数据集可以在
https://raw.githubusercontent.com/jakevdp/data-CDCbirths/master/births.csv 找到(Andrew Gelman和他的团队深⼊分析了这个数据集,
例如可以参⻅这篇博⽂):
In [10]: # if you don't have the dataset yet, you can download it with this command
# !curl -O https://raw.githubusercontent.com/jakevdp/data-CDCbirths/master/births.csv
In [10]: births = pd.read_csv('data/births.csv')
Taking a look at the data, we see that it's relatively simple–it contains the number of births grouped by date and gender:
⼤致浏览⼀遍这个数据集,发现它其实相对来说很简单,包括某年某⽉某⽇出⽣的男孩和⼥孩的个体数:
In [11]: births.head()
Out[11]:
   year  month  day gender  births
0  1969      1  1.0      F    4046
1  1969      1  1.0      M    4440
2  1969      1  2.0      F    4454
3  1969      1  2.0      M    4548
4  1969      1  3.0      F    4548
We can start to understand this data a bit more by using a pivot table. Let's add a decade column, and take a look at
male and female births as a function of decade:
我们可以通过使⽤数据透视表来更好的理解这个数据集。让我们加⼀列年代,来看⼀下每⼗年男孩和⼥孩的出⽣总数:
In [12]: births['decade'] = 10 * (births['year'] // 10)
births.pivot_table('births', index='decade', columns='gender', aggfunc='sum')
Out[12]:
gender         F         M
decade
1960     1753634   1846572
1970    16263075  17121550
1980    18310351  19243452
1990    19479454  20420553
2000    18229309  19106428
We immediately see that male births outnumber female births in every decade. To see this trend a bit more clearly, we
can use the built-in plotting tools in Pandas to visualize the total number of births by year (see Introduction to Matplotlib
for a discussion of plotting with Matplotlib):
我们会⽴刻发现男孩的出⽣⼈数在每⼀个年代都超过了⼥孩。为了更加清晰地看到这个趋势,我们可以使⽤Pandas內建的图表⼯具来展⽰
每年的男孩⼥孩的出⽣总数情况(参⻅Matplotlib介绍):
In [13]: %matplotlib inline
import matplotlib.pyplot as plt
sns.set()  # use Seaborn-style plots
births.pivot_table('births', index='year', columns='gender', aggfunc='sum').plot()
plt.ylabel('total births per year');
With a simple pivot table and plot() method, we can immediately see the annual trend in births by gender. By eye, it
appears that over the past 50 years male births have outnumbered female births by around 5%.
使⽤⼀个简单的数据透视表和內建的 plot() ⽅法,我们可以很容易的画出区分性别的出⽣数趋势图。⽤⾁眼观测,可知在过去的50年
中,男孩出⽣数⼤致⽐⼥孩出⽣数⾼出5%。
Further data exploration
进⼀步数据分析
Though this doesn't necessarily relate to the pivot table, there are a few more interesting features we can pull out of this
dataset using the Pandas tools covered up to this point. We must start by cleaning the data a bit, removing outliers
caused by mistyped dates (e.g., June 31st) or missing values (e.g., June 99th). One easy way to remove these all at
once is to cut outliers; we'll do this via a robust sigma-clipping operation:
虽然下⾯的内容不⼀定与数据透视表有关,但是我们使⽤⽬前学习到的Pandas知识,就能从数据集中获得更多有趣的特征。⾸先我们应该
对数据进⾏⼀定清洗,删除由于错误输⼊⽇期导致的离群值(例如6⽉31⽇)或者缺失值(例如6⽉99⽇)。⼀次性删除这些离群数据的简
单办法是通过⼀种叫sigma-clipping的稳健统计操作:
In [14]: # 25th, 50th and 75th percentiles of the birth counts
quartiles = np.percentile(births['births'], [25, 50, 75])
mu = quartiles[1]                           # mu: the median
sig = 0.74 * (quartiles[2] - quartiles[0])  # sigma: 0.74 times the interquartile range (75th minus 25th percentile)
This final line is a robust estimate of the sample standard deviation, where the 0.74 comes from the interquartile range of a Gaussian
distribution (you can learn more about sigma-clipping operations in a book I coauthored with Željko Ivezić, Andrew J.
Connolly, and Alexander Gray: "Statistics, Data Mining, and Machine Learning in Astronomy" (Princeton University Press,
2014)).
最后一行代码是样本标准差的稳健估计,0.74来源于正态分布的四分位距(你可以在作者与Željko Ivezić、Andrew J. Connolly和
Alexander Gray合著的书"Statistics, Data Mining, and Machine Learning in Astronomy"(Princeton University Press, 2014)中学习到更多有
关sigma-clipping方法的知识)。
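The clipping step can be packaged as a small helper. This is a minimal sketch, not code from the text: `sigma_clip_mask` is a hypothetical name, and the data are synthetic (Gaussian counts plus two invented bad entries). It contrasts the robust median/IQR cut with a naive mean/standard-deviation cut, whose sigma is inflated by the very outlier it is meant to remove.

```python
import numpy as np

def sigma_clip_mask(values, nsigma=5.0):
    """Boolean mask keeping values within nsigma robust sigmas of the median."""
    values = np.asarray(values, dtype=float)
    q25, q50, q75 = np.percentile(values, [25, 50, 75])
    # For Gaussian data IQR ~= 1.349 * sigma, so 0.74 * IQR ~= sigma,
    # and unlike np.std it barely moves when a few extreme values are present.
    sig = 0.74 * (q75 - q25)
    return (values > q50 - nsigma * sig) & (values < q50 + nsigma * sig)

# Synthetic daily counts plus two bad entries (values invented for illustration).
rng = np.random.default_rng(0)
data = np.concatenate([rng.normal(4000, 500, size=1000), [40, 99999]])

robust = sigma_clip_mask(data)
# Naive cut for comparison: 99999 inflates the standard deviation so much
# that the obviously bad value 40 survives.
naive = np.abs(data - data.mean()) < 5 * data.std()

print(robust[-2:], naive[-2:])
```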
Translator's note: for the standard normal distribution the mean is 0 and the quartiles lie at [-0.67448, 0, 0.67448], so
IQR = Q3 - Q1 = 0.67448 - (-0.67448) = 1.34896, which gives $\sigma = \mathrm{IQR}/1.34896 \approx 0.74131\,\mathrm{IQR}$.
This can be checked quickly with the following code:
In [20]: a = np.random.standard_normal(10000)
In [21]: iq = np.percentile(a, [25, 50, 75])
In [22]: iq
Out[22]: array([-0.6510475 ,  0.02099125,  0.68378426])
In [23]: 1/(iq[2] - iq[0])
Out[23]: 0.749158077468436
With this we can use the query() method (discussed further in High-Performance Pandas: eval() and query() )
to filter out rows whose births lie more than 5 sigma from the median:
然后我们可以使⽤ query() ⽅法来过滤掉偏离中位数5倍sigma值之外的所有数据( query() ⽅法我们会在⾼性能Pandas: eval() 和
query()⼩节中详细讨论):
In [15]: births = births.query('(births > @mu - 5 * @sig) & (births < @mu + 5 * @sig)')
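In a `query()` string, the `@` prefix refers to a Python variable in the calling scope rather than a column. The minimal sketch below uses an invented DataFrame and hypothetical values for `mu` and `sig`, and checks that the query gives the same rows as ordinary boolean-mask indexing.

```python
import pandas as pd

df = pd.DataFrame({'births': [4000, 4200, 40, 99999, 3900]})
mu, sig = 4000.0, 150.0   # hypothetical centre and robust spread

# '@mu' and '@sig' are looked up as local variables, 'births' as a column.
clipped = df.query('(births > @mu - 5 * @sig) & (births < @mu + 5 * @sig)')

# Equivalent boolean-mask indexing without query():
mask = (df['births'] > mu - 5 * sig) & (df['births'] < mu + 5 * sig)
same = df[mask]
print(clipped)
```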
Next we set the day column to integers; previously it had been stored as floating-point numbers because some rows in the dataset contained
the value 'null', which Pandas reads as a missing value (NaN):
下面我们将日期 day 列设为整数类型;原本该列是浮点类型,因为数据集中该列存在值 'null'(Pandas 会将其读取为缺失值 NaN):
In [16]: # convert the day column to integers
births['day'] = births['day'].astype(int)
Finally, we can combine the day, month, and year to create a Date index (see Working with Time Series). This allows us
to quickly compute the weekday corresponding to each row:
最后,我们可以将年⽉⽇合并在⼀起成为⼀个时间序列(参⻅在时间序列上操作)。这令我们可以很⽅便的求出每⼀⾏⽇期是周⼏:
In [17]: # build a datetime index from year, month and day
births.index = pd.to_datetime(10000 * births.year +
100 * births.month +
births.day, format='%Y%m%d')
births['dayofweek'] = births.index.dayofweek
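Here is a minimal sketch of the date-building step on three invented rows. It shows the YYYYMMDD-integer approach from the text (converted to strings before parsing) next to `pd.to_datetime` applied to a DataFrame with `year`, `month` and `day` columns, which assembles the dates directly; both give the same timestamps, and `dayofweek` counts Monday as 0.

```python
import pandas as pd

# Invented rows with separate integer year/month/day columns.
df = pd.DataFrame({'year': [1969, 1969, 1970],
                   'month': [1, 1, 3],
                   'day': [1, 4, 16],
                   'births': [4046, 4500, 4700]})

# Approach 1 (as in the text): encode as YYYYMMDD and parse with a format.
ymd = 10000 * df['year'] + 100 * df['month'] + df['day']
idx1 = pd.to_datetime(ymd.astype(str), format='%Y%m%d')

# Approach 2: let pandas assemble dates from the year/month/day columns.
idx2 = pd.to_datetime(df[['year', 'month', 'day']])

df.index = pd.DatetimeIndex(idx2)
df['dayofweek'] = df.index.dayofweek   # Monday=0 ... Sunday=6
print(df)
```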
Using this we can plot births by weekday for several decades:
然后我们就可以按照星期中的天数来绘制出⽣数图:
In [18]: import matplotlib.pyplot as plt
import matplotlib as mpl
births.pivot_table('births', index='dayofweek',
                    columns='decade', aggfunc='mean').plot()
# place a tick at each day (0 = Monday) so every label lines up with its day
plt.xticks(range(7), ['Mon', 'Tues', 'Wed', 'Thurs', 'Fri', 'Sat', 'Sun'])
plt.ylabel('mean births by day');
Apparently births are slightly less common on weekends than on weekdays! Note that the 1990s and 2000s are missing
because the CDC data contains only the month of birth starting in 1989.
很明显,出⽣数在休息⽇要⽐⼯作⽇少。还要注意到1990和2000年代数据缺失,原因是疾控中⼼的数据从1989年开始就只包含⽉份信息
了。
Another interesting view is to plot the mean number of births by the day of the year. Let's first group the data by month and
day separately:
另⼀个有趣的视⻆是分析每年每天的平均出⽣数。⾸先我们将⽉份和⽇期进⾏分组求平均值:
In [19]: births_by_date = births.pivot_table('births',
[births.index.month, births.index.day])
births_by_date.head()
Out[19]:
        births
1 1   4009.225
  2   4247.400
  3   4500.900
  4   4571.350
  5   4603.625
The result is a multi-index over months and days. To make this easily plottable, let's turn these months and days into a
date by associating them with a dummy year variable (making sure to choose a leap year so February 29th is correctly
handled!)
结果当然,是⼀个⽉份和⽇期的多重索引数据集。然后需要简单的绘制图表,我们可以将上⾯的⽉份⽇期随便放在⼀个闰年年份中形成完
整的时间序列(闰年是为了保证2⽉29⽇也能包含在结果集中):
In [20]: births_by_date.index = [pd.Timestamp(2012, month, day)  # pd.datetime was removed in Pandas 2.0
                        for (month, day) in births_by_date.index]
births_by_date.head()
Out[20]:
              births
2012-01-01  4009.225
2012-01-02  4247.400
2012-01-03  4500.900
2012-01-04  4571.350
2012-01-05  4603.625
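The reason for choosing a leap year as the dummy year can be checked directly. In this small sketch the (month, day) pairs are invented to include February 29: they all map to valid dates in 2012, while the same construction in a non-leap year such as 2013 raises a `ValueError`.

```python
import pandas as pd

# (month, day) pairs like the pivot table's MultiIndex, including Feb 29.
month_day = [(1, 1), (2, 28), (2, 29), (12, 31)]

# Attach the dummy leap year so every pair becomes a valid date.
dates = pd.DatetimeIndex([pd.Timestamp(2012, m, d) for m, d in month_day])

# A non-leap dummy year fails on Feb 29.
try:
    pd.Timestamp(2013, 2, 29)
    failed = False
except ValueError:
    failed = True

print(dates)
```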
Focusing on the month and day only, we now have a time series reflecting the average number of births by date of the
year. From this, we can use the plot method to plot the data. It reveals some interesting trends:
我们只需要关注数据集中的⽉份和⽇期,上⾯的结果已经是⼀个时间序列上每天出⽣数的平均值。然后我们就可以使⽤ plot ⽅法来绘制
图表。结果会反映⼀些有趣的趋势:
In [21]: # plot the mean number of births for each day of the year
fig, ax = plt.subplots(figsize=(12, 4))
births_by_date.plot(ax=ax);
In particular, the striking feature of this graph is the dip in birthrate on US holidays (e.g., Independence Day, Labor Day,
Thanksgiving, Christmas, New Year's Day) although this likely reflects trends in scheduled/induced births rather than
some deep psychosomatic effect on natural births. For more discussion on this trend, see the analysis and links in
Andrew Gelman's blog post on the subject. We'll return to this figure in Example: Effect of Holidays on US Births, where
we will use Matplotlib's tools to annotate this plot.
上图这个引⼈注⽬的结果表明出⽣率在美国假期的⽇期中都会下降(例如独⽴⽇、劳动节、感恩节、圣诞节和新年),当然这并不代表节
⽇对⾃然出⽣率在⽣物学上造成了影响,⽽是反映了医学上对⽣育的额外处理的趋势。对于有关这个趋势更多的讨论,可以参看Andrew
Gelman's blog post。我们会在第四章第九节例⼦:节⽇对美国⽣育率的影响中继续深⼊讨论这个图表,学习使⽤Matplotlib⼯具来标注这个
图表。
Looking at this short example, you can see that many of the Python and Pandas tools we've seen to this point can be
combined and used to gain insight from a variety of datasets. We will see some more sophisticated applications of these
data manipulations in future sections!
看完了这个简短的例⼦,你就可以发现我们已经学习到的很多Python和Pandas的⼯具可以联合使⽤来深⼊分析不同的数据集以获得需要的
结果。在后续章节中我们会看到⼀些对于数据的操作更复杂的应⽤。
Vectorized String Operations
向量化的字符串操作
One strength of Python is its relative ease in handling and manipulating string data. Pandas builds on this and provides a
comprehensive set of vectorized string operations that become an essential piece of the type of munging required when
working with (read: cleaning up) real-world data. In this section, we'll walk through some of the Pandas string operations,
and then take a look at using them to partially clean up a very messy dataset of recipes collected from the Internet.
Python 的一个强大的特点就是它能相对简单的处理和操作字符串数据。Pandas在此基础上提供了一整套向量化字符串操作,这成为了当我
们处理(清洗)真实世界数据时非常关键的功能。在本节中,我们将对很多Pandas的字符串操作进行介绍,然后看它们在我们对从互联网
采集到的非常不规范的数据集进行清洗时发挥的作用。
Introducing Pandas String Operations
Pandas 字符串操作介绍
We saw in previous sections how tools like NumPy and Pandas generalize arithmetic operations so that we can easily
and quickly perform the same operation on many array elements. For example:
在前⾯章节中我们看到NumPy和Pandas的⼯具能向量化算术运算,让我们可以很容易和快速的对数组的元素进⾏相同的数学计算。例
如:
In [1]: import numpy as np
x = np.array([2, 3, 5, 7, 11, 13])
x * 2
Out[1]: array([ 4,  6, 10, 14, 22, 26])
This vectorization of operations simplifies the syntax of operating on arrays of data: we no longer have to worry about the
size or shape of the array, but just about what operation we want done. For arrays of strings, NumPy does not provide
such simple access, and thus you're stuck using a more verbose loop syntax:
这种向量化的操作能简化数组元素的操作语法:我们不再需要担⼼数组的⼤⼩和形状,只需要关注于需要进⾏的运算本⾝。对于字符串数
组,NumPy没有提供这种简单的操作,因此你需要继续使⽤循环语法来处理:
In [2]: data = ['peter', 'Paul', 'MARY', 'gUIDO']
[s.capitalize() for s in data]
Out[2]: ['Peter', 'Paul', 'Mary', 'Guido']
This is perhaps sufficient to work with some data, but it will break if there are any missing values. For example:
这可能对于⼀些数据集来说⾜够了,但是对于含有缺失值的数据集来说就出问题了。例如:
In [3]: data = ['peter', 'Paul', None, 'MARY', 'gUIDO']
[s.capitalize() for s in data]
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-3-3b0264c38d59> in <module>
      1 data = ['peter', 'Paul', None, 'MARY', 'gUIDO']
----> 2 [s.capitalize() for s in data]

<ipython-input-3-3b0264c38d59> in <listcomp>(.0)
      1 data = ['peter', 'Paul', None, 'MARY', 'gUIDO']
----> 2 [s.capitalize() for s in data]

AttributeError: 'NoneType' object has no attribute 'capitalize'
Pandas includes features to address both this need for vectorized string operations and for correctly handling missing
data via the str attribute of Pandas Series and Index objects containing strings. So, for example, suppose we create a
Pandas Series with this data:
Pandas 包含了前面说到的向量化的字符串操作,而且还能正确的处理缺失值,这可以通过Pandas的Series和Index对象的 str 属性来实
现。例如,假设我们如下创建一个Pandas Series:
In [4]: import pandas as pd
names = pd.Series(data)
names
Out[4]:
0    peter
1    Paul
2    None
3    MARY
4    gUIDO
dtype: object
We can now call a single method that will capitalize all the entries, while skipping over any missing values:
我们现在可以调⽤⼀个⽅法就能将所有元素⾸字⺟变⼤写的功能,并能跳过缺失值:
In [5]: names.str.capitalize()
Out[5]:
0    Peter
1    Paul
2    None
3    Mary
4    Guido
dtype: object
Using tab completion on this str attribute will list all the vectorized string methods available to Pandas.
在IPython中在 str 属性上使⽤制表符⾃动补全功能可以列出Pandas中⽀持的所有的向量化字符串操作。
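A short sketch of the missing-value handling: the `.str` method applies element-wise and passes the missing entry through, whereas a plain comprehension needs an explicit guard (here an `isinstance` check, one possible choice) to avoid the `AttributeError` shown earlier.

```python
import pandas as pd

names = pd.Series(['peter', 'Paul', None, 'MARY', 'gUIDO'])

# Vectorized: missing values are skipped and stay missing.
capitalized = names.str.capitalize()

# Manual equivalent: guard every element that is not a string.
manual = [s.capitalize() if isinstance(s, str) else s for s in names]
print(capitalized)
```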
Tables of Pandas String Methods
Pandas 字符串方法列表
If you have a good understanding of string manipulation in Python, most of Pandas string syntax is intuitive enough that
it's probably sufficient to just list a table of available methods; we will start with that here, before diving deeper into a few
of the subtleties. The examples in this section use the following series of names:
如果你已经很好理解了Python中的字符串操作,⼤多数的Pandas字符串操作语法是很直观的,因此简单的列⼀张表格说明所有可⽤的⽅法
就⾜够理解了;我们在深⼊到⼀些细节之前可以先浏览这张表。本节的例⼦将会使⽤下⾯的⼀个姓名的Series对象:
In [6]: monte = pd.Series(['Graham Chapman', 'John Cleese', 'Terry Gilliam',
'Eric Idle', 'Terry Jones', 'Michael Palin'])
Methods similar to Python string methods
类似Python的字符串⽅法
Nearly all Python's built-in string methods are mirrored by a Pandas vectorized string method. Here is a list of Pandas
str methods that mirror Python string methods:
⼏乎所有Python內建的字符串⽅法都有Pandas的向量化版本。下⾯是Pandas的 str 属性中与Python內建字符串⽅法⼀致的⽅法:
len()      lower()       translate()   islower()
ljust()    upper()       startswith()  isupper()
rjust()    find()        endswith()    isnumeric()
center()   rfind()       isalnum()     isdecimal()
zfill()    index()       isalpha()     split()
strip()    rindex()      isdigit()     rsplit()
rstrip()   capitalize()  isspace()     partition()
lstrip()   swapcase()    istitle()     rpartition()
Notice that these have various return values. Some, like lower() , return a series of strings:
要提醒的是,这些⽅法与內建字符串⽅法可能有着不同的返回值,如 lower() 返回的是⼀个字符串的Series对象:
In [7]: monte.str.lower()
Out[7]:
0    graham chapman
1    john cleese
2    terry gilliam
3    eric idle
4    terry jones
5    michael palin
dtype: object
But some others return numbers:
另外⼀些返回的是数字的Series对象:
In [8]: monte.str.len()
Out[8]:
0    14
1    11
2    13
3     9
4    11
5    13
dtype: int64
Or Boolean values:
或布尔值的Series对象:
In [9]: monte.str.startswith('T')
Out[9]:
0    False
1    False
2     True
3    False
4     True
5    False
dtype: bool
Still others return lists or other compound values for each element:
还有⼀些会返回诸如列表那样的复合类型的Series对象:
In [10]: monte.str.split()
Out[10]:
0    [Graham, Chapman]
1    [John, Cleese]
2    [Terry, Gilliam]
3    [Eric, Idle]
4    [Terry, Jones]
5    [Michael, Palin]
dtype: object
We'll see further manipulations of this kind of series-of-lists object as we continue our discussion.
我们后⾯会讨论到如何操作这种列表组成的Series对象。
Methods using regular expressions
使⽤正则表达式的⽅法
In addition, there are several methods that accept regular expressions to examine the content of each string element, and
follow some of the API conventions of Python's built-in re module:
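除此之外,还有一些方法可以接受正则表达式来检查每个元素字符串是否匹配模式,它们遵从Python內建的 re 模块的API规范:
Method      Description
match()     Call re.match() on each element, returning a boolean Series
extract()   Call re.search() on each element, returning the matched capture groups
findall()   Call re.findall() on each element
replace()   Replace occurrences of the pattern with some other string
contains()  Call re.search() on each element, returning a boolean Series
count()     Count occurrences of the pattern
split()     Equivalent to str.split(), but accepts regular expressions
rsplit()    Equivalent to str.rsplit(): splits on a delimiter string, working from the end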
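A few of the regex-based methods from the table, applied to the `monte` Series defined earlier. This is a minimal sketch; the patterns are illustrative choices, not taken from the text.

```python
import pandas as pd

monte = pd.Series(['Graham Chapman', 'John Cleese', 'Terry Gilliam',
                   'Eric Idle', 'Terry Jones', 'Michael Palin'])

# contains(): re.search on each element -> boolean Series
mentions_terry = monte.str.contains(r'[Tt]erry')

# count(): number of non-overlapping matches in each element
n_vowels = monte.str.count(r'[aeiouAEIOU]')

# replace(): regex substitution; regex=True is passed explicitly
initial_and_last = monte.str.replace(r'^(\w)\w*', r'\1.', regex=True)

# match(): re.match is anchored at the start of each string
starts_with_t = monte.str.match(r'T')

print(initial_and_last)
```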
With these, you can do a wide range of interesting operations. For example, we can extract the first name from each by
asking for a contiguous group of characters at the beginning of each element:
使⽤上⾯的⽅法,你可以执⾏很多有趣的操作。例如,我们可以通过匹配连续的⼀组字⺟的模式从姓名中提取出名字:
In [11]: monte.str.extract('([A-Za-z]+)', expand=False)
Out[11]:
0    Graham
1    John
2    Terry
3    Eric
4    Terry
5    Michael
dtype: object
Or we can do something more complicated, like finding all names that start and end with a consonant, making use of the
start-of-string ( ^ ) and end-of-string ( $ ) regular expression characters:
或者我们可以执⾏更复杂的操作,如找出所有姓名中⾸字⺟和尾字⺟都是辅⾳字⺟的⼈,这⾥需要使⽤字符串开始位置( ^ )和字符串结
束位置( $ )正则表达式特殊符号:
In [12]: monte.str.findall(r'^[^AEIOU].*[^aeiou]$')
Out[12]:
0    [Graham Chapman]
1    []
2    [Terry Gilliam]
3    []
4    [Terry Jones]
5    [Michael Palin]
dtype: object
The ability to concisely apply regular expressions across Series or DataFrame entries opens up many possibilities
for analysis and cleaning of data.
这种在 Series 或 DataFrame 上简洁的应⽤正则表达式的特性,在清洗和分析数据任务中⾮常有⽤。
Miscellaneous methods
其他⽅法
Finally, there are some miscellaneous methods that enable other convenient operations:
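最后,下面是一些无法分类的其他方法但也是很方便的字符串功能:
Method           Description
get()            Index into each element
slice()          Slice each element
slice_replace()  Replace the slice in each element with a passed value
cat()            Concatenate strings (with no other argument, all elements into one string)
repeat()         Repeat each string element
normalize()      Return the Unicode normal form of each string
pad()            Add whitespace to the left, right, or both sides of strings
wrap()           Split long strings into lines shorter than a given width
join()           Join the characters of each string (or the items of each list) with a passed separator
get_dummies()    Split strings on a separator and return a DataFrame of dummy (indicator) variables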
Vectorized item access and slicing
向量化的索引和切⽚操作
The get() and slice() operations, in particular, enable vectorized element access from each string. For example,
we can get a slice of the first three characters of each string using str.slice(0, 3) . Note that this behavior is also
available through Python's normal indexing syntax; for example, monte.str.slice(0, 3) is equivalent to
monte.str[0:3] :
get() 和 slice() 操作,可以对每个字符串元素进行索引访问和切片的操作。例如,我们可以通过 str.slice(0, 3) 获取每个字符
串元素的前三个字母。还需要说明的是,这个操作也可以通过Python标准的切片语法来完成,也就是说 monte.str[0:3] 等同于
monte.str.slice(0, 3) :
In [13]: monte.str[:3]
Out[13]:
0    Gra
1    Joh
2    Ter
3    Eri
4    Ter
5    Mic
dtype: object
Likewise, indexing via monte.str.get(i) is equivalent to monte.str[i] .
索引取值操作也是⼀样, df.str[i] 等同于 df.str.get(i) 。
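A quick sketch confirming both equivalences on the first three names from `monte`:

```python
import pandas as pd

monte = pd.Series(['Graham Chapman', 'John Cleese', 'Terry Gilliam'])

first3_a = monte.str.slice(0, 3)
first3_b = monte.str[0:3]      # same result via slicing syntax

second_a = monte.str.get(1)
second_b = monte.str[1]        # same result via indexing syntax
print(first3_a)
```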
These get() and slice() methods also let you access elements of arrays returned by split() . For example, to
extract the last name of each entry, we can combine split() and get() :
get() 和 slice() 方法还能支持对 split() 返回的列表进行取值操作。例如我们使用 split() 和 get() 方法可以提取出每个人的
姓:
In [14]: monte.str.split().str.get(-1)
Out[14]:
0    Chapman
1    Cleese
2    Gilliam
3    Idle
4    Jones
5    Palin
dtype: object
Indicator variables
指⽰器变量
Another method that requires a bit of extra explanation is the get_dummies() method. This is useful when your data
has a column containing some sort of coded indicator. For example, we might have a dataset that contains information in
the form of codes, such as A="born in America," B="born in the United Kingdom," C="likes cheese," D="likes spam":
还有⼀个需要进⾏说明的⽅法是 get_dummies() 。这个⽅法在你的数据中含有某种编码的指⽰器的时候⾮常有⽤。例如,我们有⼀个数
据集,⾥⾯有⼀个编码了的列,A代表“出⽣在美国”,B代表“出⽣在英国”,C代表“喜欢芝⼠”,D代表“喜欢⾁罐头”:
In [15]: full_monte = pd.DataFrame({'name': monte,
'info': ['B|C|D', 'B|D', 'A|C',
'B|D', 'B|C', 'B|C|D']})
full_monte
Out[15]:
             name   info
0  Graham Chapman  B|C|D
1     John Cleese    B|D
2   Terry Gilliam    A|C
3       Eric Idle    B|D
4     Terry Jones    B|C
5   Michael Palin  B|C|D
The get_dummies() routine lets you quickly split out these indicator variables into a DataFrame :
get_dummies() 方法能让你快速的将这些编码的指示器变量分解出来,并形成一个 DataFrame :
In [16]: full_monte['info'].str.get_dummies('|')
Out[16]:
   A  B  C  D
0  0  1  1  1
1  0  1  0  1
2  1  0  1  0
3  0  1  0  1
4  0  1  1  0
5  0  1  1  1
With these operations as building blocks, you can construct an endless range of string processing procedures when
cleaning your data.
有了上述的这些向量化字符串⽅法,你可以在清洗数据时构建⽆穷⽆尽的字符串处理流程。
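As a small sketch of how the indicator columns are built: `get_dummies(sep='|')` splits each string on the separator and creates one 0/1 column per distinct code, sorted by name. The three codes below are the first three rows of `full_monte['info']`; the row-sum check is simply a consistency test added here.

```python
import pandas as pd

info = pd.Series(['B|C|D', 'B|D', 'A|C'])

# One column per distinct code; 1 where the code appears in that row.
dummies = info.str.get_dummies(sep='|')
print(dummies)
```

Each row of `dummies` sums to the number of codes in the original string, which is one more than the number of `|` separators.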
We won't dive further into these methods here, but I encourage you to read through "Working with Text Data" in the
Pandas online documentation, or to refer to the resources listed in Further Resources.
在这里我们不再深入介绍每个方法,作者推荐你去阅读Pandas的在线文档处理文本数据,或者参考更多资源中的其他资料。
Example: Recipe Database
例⼦:菜谱数据库
These vectorized string operations become most useful in the process of cleaning up messy, real-world data. Here I'll
walk through an example of that, using an open recipe database compiled from various sources on the Web. Our goal will
be to parse the recipe data into ingredient lists, so we can quickly find a recipe based on some ingredients we have on
hand.
上述介绍的这些向量化字符串操作是我们对不规范的真实世界数据进⾏清洗的最有效⼯具。下⾯我们将使⽤⽹络中收集的⼀个菜谱数据库
作为例⼦来总体说明。我们的⽬标是将这些菜谱数据解析成配⽅的列表,这样我们就能很快速的根据我们⼿头的材料找到相应配⽅的菜
谱。
The scripts used to compile this can be found at https://github.com/fictivekin/openrecipes, and the link to the current
version of the database is found there as well.
⽤来收集菜谱的脚本可以在https://github.com/fictivekin/openrecipes 这⾥找到,最新版本的菜谱数据库的连接也在这个⻚⾯上。
As of Spring 2016, this database is about 30 MB, and can be downloaded and unzipped with these commands:
到2016年春天,这个数据库有⼤约30MB⼤⼩,可以通过下⾯的命令下载以及解压:
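Translator's note: since 2017 the open recipes database can no longer be downloaded from the original author's address; see Issue #179.
The database can be downloaded correctly from a new address, and the shell commands below have been changed to use it.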
In [17]: # !curl -O https://s3.amazonaws.com/openrecipes/20170107-061401-recipeitems.json.gz
# !gunzip 20170107-061401-recipeitems.json.gz
The database is in JSON format, so we will try pd.read_json to read it:
数据库是JSON格式,因此我们需要使⽤ pd.read_json ⽅法来读取它:
In [18]: try:
    recipes = pd.read_json('data/20170107-061401-recipeitems.json')
except ValueError as e:
    print("ValueError:", e)
ValueError: Trailing data
Oops! We get a ValueError mentioning that there is "trailing data." Searching for the text of this error on the Internet,
it seems that it's due to using a file in which each line is itself a valid JSON, but the full file is not. Let's check if this
interpretation is true:
喔噢,这⾥会产⽣⼀个 ValueError 指出有冗余的数据。通过在⽹上搜索这个错误信息,我们得到原因是这个⽂件每⼀⾏都是⼀个正确
的JSON,但是整个⽂件不是正确的JSON格式。我们来验证⼀下:
In [19]: with open('data/20170107-061401-recipeitems.json') as f:
    line = f.readline()
pd.read_json(line).shape
Out[19]: (2, 12)
Yes, apparently each line is a valid JSON, so we'll need to string them together. One way we can do this is to actually
construct a string representation containing all these JSON entries, and then load the whole thing with pd.read_json :
通过读取⽂件⼀⾏我们验证了我们的想法,现在我们需要将这些正确的JSON⾏合并在⼀起。实现这个⽬标的⼀种⽅式就是我们⼿动将所
有的⾏合并成⼀个JSON Array,然后将这个JSON Array的字符串传递到 pd.read_json 来进⾏解析:
In [20]: from io import StringIO
# Merge the JSON object on each line into one JSON array
with open('data/20170107-061401-recipeitems.json', 'r') as f:
    # extract each line
    data = (line.strip() for line in f)
    # separate the objects with ',' and wrap them in brackets to form an array
    data_json = "[{0}]".format(','.join(data))
# parse the resulting string as JSON into Pandas
recipes = pd.read_json(StringIO(data_json))
In [21]: recipes.shape
Out[21]: (173278, 17)
We see there are more than 170,000 recipes, and 17 columns. Let's take a look at one row to see what we have:
从形状可知有超过17万个菜谱,每个菜谱有17列数据。看看其中的一行:
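Pandas can also read line-delimited JSON ("JSON Lines") directly via the `lines=True` argument of `pd.read_json`, which avoids building the array string by hand. This minimal sketch uses two invented records held in memory, wrapped in `io.StringIO` so they behave like a file, and checks that both approaches give the same DataFrame.

```python
import io
import pandas as pd

# Two invented records: each line is valid JSON, the whole text is not.
jsonl = ('{"name": "Toast", "ingredients": "bread\\nbutter"}\n'
         '{"name": "Tea", "ingredients": "water\\ntea leaves"}\n')

# Approach from the text: join the lines into one JSON array string.
array_text = "[{0}]".format(','.join(line.strip() for line in io.StringIO(jsonl)))
df_array = pd.read_json(io.StringIO(array_text))

# Built-in alternative: declare the input as line-delimited JSON.
df_lines = pd.read_json(io.StringIO(jsonl), lines=True)
print(df_lines)
```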
In [22]: recipes.iloc[0]
Out[22]:
_id                   {'$oid': '5160756b96cc62079cc2db15'}
cookTime              PT30M
creator               NaN
dateModified          NaN
datePublished         2013-03-11
description           Late Saturday afternoon, after Marlboro Man ha...
image                 http://static.thepioneerwoman.com/cooking/file...
ingredients           Biscuits\n3 cups All-purpose Flour\n2 Tablespo...
name                  Drop Biscuits and Sausage Gravy
prepTime              PT10M
recipeCategory        NaN
recipeInstructions    NaN
recipeYield           12
source                thepioneerwoman
totalTime             NaN
ts                    {'$date': 1365276011104}
url                   http://thepioneerwoman.com/cooking/2013/03/dro...
Name: 0, dtype: object
There is a lot of information there, but much of it is in a very messy form, as is typical of data scraped from the Web. In
particular, the ingredient list is in string format; we're going to have to carefully extract the information we're interested in.
Let's start by taking a closer look at the ingredients:
结果中有很多的数据,但⼤多数列的数据都是混乱不堪的,正如所有从⽹络中爬取的数据⼀样。特别注意到,配⽅列表是⼀个字符串的格
式,因此我们需要特别⼩⼼的在我们感兴趣的列中进⾏数据提取操作。让我们⼤致看⼀些配⽅列的基本情况:
In [23]: recipes.ingredients.str.len().describe()
Out[23]:
count    173278.000000
mean        244.617926
std         146.705285
min           0.000000
25%         147.000000
50%         221.000000
75%         314.000000
max        9067.000000
Name: ingredients, dtype: float64
The ingredient lists average 250 characters long, with a minimum of 0 and a maximum of nearly 10,000 characters!
配⽅的列表平均有250个字符⻓,最短的是0个字符,⽽最⻓的能达到接近10000个字符。
Just out of curiosity, let's see which recipe has the longest ingredient list:
为了满足一下好奇心,让我们看看哪个菜谱有着最长的配方列表:
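Translator's note: with Pandas 0.24, calling argmax() on a Series, or np.argmax(series), produces the warning below, because later
Pandas versions change this behavior: at the time it returned the index label of the maximum, and it was to be changed to return its
position. Series.idxmax() is recommended instead, so the translator also added a version that uses idxmax().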
In [24]: recipes.name[np.argmax(recipes.ingredients.str.len())]
/home/wangy/anaconda3/lib/python3.7/site-packages/numpy/core/fromnumeric.py:56: FutureWarning:
The current behaviour of 'Series.argmax' is deprecated, use 'idxmax'
instead.
The behavior of 'argmax' will be corrected to return the positional
maximum in the future. For now, use 'series.values.argmax' or
'np.argmax(np.array(values))' to get the position of the maximum
row.
return getattr(obj, method)(*args, **kwds)
Out[24]: 'Carrot Pineapple Spice &amp; Brownie Layer Cake with Whipped Cream &amp; Cream Cheese Frosting and Marzipan Carrots'
In [25]: recipes.name[recipes.ingredients.str.len().idxmax()]
Out[25]: 'Carrot Pineapple Spice &amp; Brownie Layer Cake with Whipped Cream &amp; Cream Cheese Frosting and Marzipan Carrots'
That certainly looks like an involved recipe.
这个菜谱看起来就很复杂的样⼦。
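The argmax/idxmax distinction discussed above is easy to see on a Series whose index is not 0, 1, 2, .... A minimal sketch on made-up ingredient strings (not the recipe dataset): `idxmax()` returns the index label of the maximum, while `Series.argmax()` returns its integer position, which is its behavior in pandas 1.0 and later.

```python
import pandas as pd

# Hypothetical ingredient strings with a non-default integer index
ingredients = pd.Series(
    ["salt", "flour, sugar, butter, eggs", "water"],
    index=[10, 20, 30],
)
lengths = ingredients.str.len()

label = lengths.idxmax()       # index label of the longest entry
position = lengths.argmax()    # integer position of the longest entry

print(label, position)             # 20 1
print(ingredients.loc[label])      # label-based lookup
print(ingredients.iloc[position])  # position-based lookup; same entry
```

Pairing `idxmax()` with `.loc` and `argmax()` with `.iloc` keeps the lookup correct whatever the index looks like.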
We can do other aggregate explorations; for example, let's see how many of the recipes are for breakfast food:
我们可以继续研究⼀下这个数据集,例如,看看有多少种早餐的菜谱:
In [26]: recipes.description.str.contains('[Bb]reakfast').sum()
Out[26]: 3524
Or how many of the recipes list cinnamon as an ingredient:
或者多少种菜谱中用到了肉桂做原料:
In [27]: recipes.ingredients.str.contains('[Cc]innamon').sum()
Out[27]: 10526
We could even look to see whether any recipes misspell the ingredient as "cinamon":
我们甚⾄可以找到有多少种菜谱将⾁桂拼写成了“⾁挂”(cinamon):
In [28]: recipes.ingredients.str.contains('[Cc]inamon').sum()
Out[28]: 11
This is the type of essential data exploration that is possible with Pandas string tools. It is data munging like this that
Python really excels at.
这些类型的基础数据分析⼯作,都可以通过Pandas的字符串⼯具进⾏并获得结果。这正是Python在数据科学领域优于其他语⾔的地⽅。
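The counts above work because `str.contains` returns a Boolean Series, and `sum()` counts its True entries. A minimal sketch on made-up ingredient lists (hypothetical data, not the recipe dataset) compares a character-class pattern with the `case=False` option; `na=False` treats missing entries as non-matches:

```python
import pandas as pd

# Hypothetical ingredient lists; None stands in for a missing entry
ingredients = pd.Series([
    "2 tsp Cinnamon, 1 cup sugar",
    "cinnamon sticks",
    "1 tsp CINNAMON",
    "a pinch of cinamon",      # misspelled
    None,
])

# Character class: matches "Cinnamon" and "cinnamon", but not "CINNAMON"
by_class = ingredients.str.contains("[Cc]innamon", na=False).sum()

# case=False: matches any capitalization
any_case = ingredients.str.contains("cinnamon", case=False, na=False).sum()

# The misspelling is only caught by its own pattern
misspelled = ingredients.str.contains("[Cc]inamon", na=False).sum()

print(by_class, any_case, misspelled)  # 2 3 1
```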
A simple recipe recommender
一个简单的菜谱推荐器
Let's go a bit further, and start working on a simple recipe recommendation system: given a list of ingredients, find a
recipe that uses all those ingredients. While conceptually straightforward, the task is complicated by the heterogeneity of
the data: there is no easy operation, for example, to extract a clean list of ingredients from each row. So we will cheat a
bit: we'll start with a list of common ingredients, and simply search to see whether they are in each recipe's ingredient list.
For simplicity, let's just stick with herbs and spices for the time being:
我们再深⼊的研究⼀点,来试试实现⼀个简单的菜谱推荐系统:给定⼀系列的原材料组成的配⽅,找到应⽤了所有原料的菜谱。虽然看起
来很容易,实际上这个任务的难点在于数据的异构性:即⽆法找到⼀个简单的操作,能从每⼀⾏中提取出⼲净的原料列表。因此我们来做
个弊:我们构建⼀个很通⽤的原料列表,然后在每个菜谱的配⽅中搜索它们是否存在。为简单起⻅,⾸先从草药和⾹料开始:
In [29]: spice_list = ['salt', 'pepper', 'oregano', 'sage', 'parsley',
'rosemary', 'tarragon', 'thyme', 'paprika', 'cumin']
We can then build a Boolean DataFrame consisting of True and False values, indicating whether each spice appears in each recipe's ingredient list:
然后就能创建⼀个布尔 DataFrame ,显⽰上述列表的原料是否在每个菜谱中存在:
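In [30]: spice_df = pd.DataFrame({spice: recipes.ingredients.str.contains(spice, case=False)
                         for spice in spice_list})
# case=False gives the intended case-insensitive match. The original code passed
# re.IGNORECASE positionally, where str.contains() reads it as `case` (truthy, hence
# case-sensitive); the outputs shown in this section come from that original call.
spice_df.head()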
Out[30]:
    salt  pepper  oregano   sage  parsley  rosemary  tarragon  thyme  paprika  cumin
0  False   False    False   True    False     False     False  False    False  False
1  False   False    False  False    False     False     False  False    False  False
2   True    True    False  False    False     False     False  False    False   True
3  False   False    False  False    False     False     False  False    False  False
4  False   False    False  False    False     False     False  False    False  False
Now, as an example, let's say we'd like to find a recipe that uses parsley, paprika, and tarragon. We can compute this very quickly using the query() method of DataFrames, discussed in High-Performance Pandas: eval() and query():
有了这个布尔类型的数据集,比方说如果我们想要找到一个配方使用了欧芹、红椒粉和龙蒿,那么我们就可以很方便地使用 DataFrame 的 query() 方法来找到这些配方( query 方法详见高性能Pandas: eval() 和 query()):
In [31]: selection = spice_df.query('parsley & paprika & tarragon')
len(selection)
Out[31]: 10
We find only 10 recipes with this combination; let's use the index returned by this selection to discover the names of the
recipes that have this combination:
这种原料组合的菜谱只有10种;我们可以⽤结果的⾏索引将这些菜谱的名称筛选出来:
In [32]: recipes.name[selection.index]
Out[32]:
2069      All cremat with a Little Gem, dandelion and wa...
74964     Lobster with Thermidor butter
93768     Burton's Southern Fried Chicken with White Gravy
113926    Mijo's Slow Cooker Shredded Beef
137686    Asparagus Soup with Poached Eggs
140530    Fried Oyster Po’boys
158475    Lamb shank tagine with herb tabbouleh
158486    Southern fried chicken in buttermilk
163175    Fried Chicken Sliders with Pickles + Slaw
165243    Bar Tartine Cauliflower Salad
Name: name, dtype: object
Now that we have narrowed down our recipe selection by a factor of almost 20,000, we are in a position to make a more
informed decision about what we'd like to cook for dinner.
现在我们就已经将可选择的菜谱范围缩⼩了将近两万倍,我们可以很容易的对晚餐的菜谱做出最后的选择。
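The whole pipeline above (Boolean spice table, then `query()`, then index lookup) can be reproduced in miniature. This sketch uses a handful of hypothetical recipes rather than the real dataset, and builds the Boolean table with `case=False` so that capitalized spice names also match:

```python
import pandas as pd

# Hypothetical recipes: name plus a free-text ingredient list
recipes = pd.DataFrame({
    "name": ["Herb Chicken", "Spiced Soup", "Plain Toast", "Green Salad"],
    "ingredients": [
        "Parsley, paprika, tarragon, chicken",
        "cumin, Paprika, salt",
        "bread, butter",
        "parsley, tarragon, lettuce, PAPRIKA",
    ],
})
spice_list = ["salt", "parsley", "paprika", "tarragon", "cumin"]

# One Boolean column per spice: does the spice appear in each recipe?
spice_df = pd.DataFrame({
    spice: recipes["ingredients"].str.contains(spice, case=False)
    for spice in spice_list
})

# Rows whose recipes contain all three spices
selection = spice_df.query("parsley & paprika & tarragon")
print(recipes["name"][selection.index].tolist())  # ['Herb Chicken', 'Green Salad']
```

Because `spice_df` shares the index of `recipes`, the index of the query result can be used directly to look up recipe names.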
Going further with recipes
更加深⼊的分析菜谱数据库
Hopefully this example has given you a bit of a flavor (ba-dum!) for the types of data cleaning operations that are
efficiently enabled by Pandas string methods. Of course, building a very robust recipe recommendation system would
require a lot more work! Extracting full ingredient lists from each recipe would be an important piece of the task;
unfortunately, the wide variety of formats used makes this a relatively time-consuming process. This points to the truism
that in data science, cleaning and munging of real-world data often comprises the majority of the work, and Pandas
provides the tools that can help you do this efficiently.
希望上⾯的推荐器例⼦能够让你认识到了使⽤Pandas的字符串⽅法我们可以对数据进⾏异常⽅便的清洗操作。当然如果希望构建⼀个成熟
的菜谱推荐系统的话,需要⽐上例复杂的多的技巧和⼯程。将每个菜谱中的原料配⽅提取出来变成⼀个列表会是其中很重要的⼀环;不幸
的是,因为数据格式的多样性,这项任务会相对很耗时。这阐述了数据科学中的⼀个事实,那就是清洗和预处理真实世界的数据是这个领
域⾮常主要的⼯作之⼀,Pandas提供了⼀些⼯具能帮助你很有效率的完成它。
Working with Time Series
在时间序列上操作
Pandas was developed in the context of financial modeling, so as you might expect, it contains a fairly extensive set of
tools for working with dates, times, and time-indexed data. Date and time data comes in a few flavors, which we will
discuss here:
Time stamps reference particular moments in time (e.g., July 4th, 2015 at 7:00am).
Time intervals and periods reference a length of time between a particular beginning and end point; for example, the
year 2015. Periods usually reference a special case of time intervals in which each interval is of uniform length and
does not overlap (e.g., 24 hour-long periods comprising days).
Time deltas or durations reference an exact length of time (e.g., a duration of 22.56 seconds).
Pandas的发展过程具有很强的金融领域背景,因此你可以预料到,它包括一整套用于处理日期、时间和时间索引数据的工具。日期和时间数据有如下几类,我们会在本节中进行讨论:
时间戳 代表着一个特定的时间点(例如2015年7月4日上午7点)。
时间间隔和周期 代表着从开始时间点到结束时间点之间的一段时间长度;例如2015一整年。周期通常代表一种特殊的时间间隔,每个时间间隔的长度都是统一的,彼此之间不重叠(例如由24小时构成的一天)。
时间差或持续时间 代表着一段准确的时间长度(例如22.56秒的持续时间)。
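The three kinds of date/time data listed above map onto pandas scalar types, which are introduced in more detail later in this section. A short sketch, with values chosen to mirror the examples in the list; `pd.Period("2015")` infers an annual frequency from the string:

```python
import pandas as pd

# A time stamp: one particular moment
stamp = pd.Timestamp("2015-07-04 07:00")

# A period: the whole year 2015 (annual frequency inferred from the string)
year = pd.Period("2015")

# A duration: an exact length of time
duration = pd.Timedelta(seconds=22.56)

print(stamp)                        # 2015-07-04 07:00:00
print(year.start_time)              # 2015-01-01 00:00:00
print(duration.total_seconds())     # 22.56

# The kinds combine naturally: a stamp plus a duration is a new stamp,
# and a stamp can be tested for membership in a period
print(stamp + duration)
print(year.start_time <= stamp <= year.end_time)   # True
```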
In this section, we will introduce how to work with each of these types of date/time data in Pandas. This short section is
by no means a complete guide to the time series tools available in Python or Pandas, but instead is intended as a broad
overview of how you as a user should approach working with time series. We will start with a brief discussion of tools for
dealing with dates and times in Python, before moving more specifically to a discussion of the tools provided by Pandas.
After listing some resources that go into more depth, we will review some short examples of working with time series data
in Pandas.
在本节中,我们将介绍在Pandas中如何使用上述的这些时间类型数据。这个简短的小节不可能覆盖Python或Pandas中所有时间序列工具的内容,但可以作为引导你入门使用它们的一个概述。我们首先简要介绍一些在Python当中处理日期时间的工具,然后再详细介绍Pandas提供的相应工具。在列出更加深入的学习资源之后,我们还会使用一些简短的例子来说明在Pandas中怎样处理时间序列数据。
Dates and Times in Python
Python中的日期和时间
The Python world has a number of available representations of dates, times, deltas, and timespans. While the time series
tools provided by Pandas tend to be the most useful for data science applications, it is helpful to see their relationship to
other packages used in Python.
Python本身就带有很多有关日期、时间、时间差和时间间隔的表示方法。Pandas提供的时间序列工具在数据科学领域会更加有用,但是首先了解它们与Python中其他工具包的关系,会对我们理解它们更加有帮助。
Native Python dates and times: datetime and dateutil
原⽣Python⽇期和时间: datetime 和 dateutil
Python's basic objects for working with dates and times reside in the built-in datetime module. Along with the third-party dateutil module, you can use it to quickly perform a host of useful operations on dates and times. For example, you can manually build a date using the datetime type:
Python最基础的日期和时间处理模块就是内建的 datetime 。如果加上第三方的 dateutil 模块,你就能迅速地对日期和时间进行许多有用的操作了。例如,你可以手动创建一个 datetime 对象:
In [1]: from datetime import datetime
datetime(year=2015, month=7, day=4)
Out[1]: datetime.datetime(2015, 7, 4, 0, 0)
Or, using the dateutil module, you can parse dates from a variety of string formats:
或者使⽤ dateutil 模块,你可以从许多不同的字符串格式中解析出 datetime 对象:
In [2]: from dateutil import parser
date = parser.parse("4th of July, 2015")
date
Out[2]: datetime.datetime(2015, 7, 4, 0, 0)
Once you have a datetime object, you can do things like printing the day of the week:
获得 datetime 对象之后,你可以对它进⾏很多操作,包括输出这天是星期⼏:
In [3]: date.strftime('%A')
Out[3]: 'Saturday'
In the final line, we've used one of the standard string format codes for printing dates ("%A"), which you can read about in the strftime section of Python's datetime documentation. Documentation of other useful date utilities can be found in dateutil's online documentation. A related package to be aware of is pytz, which contains tools for working with the most migraine-inducing piece of time series data: time zones.
在上⾯的代码中,我们使⽤了标准的字符串格式化编码来打印⽇期( "%A" ),你可以在时间格式化在线⽂档中看到全部的说明。Python
的 datetime 在线⽂档可以参考datetime⽂档。其他很有⽤的⽇期时间⼯具 dateutil 的⽂档可在dateutil在线⽂档找到。还有⼀个值得
注意的第三⽅包是 pytz ,⽤来处理最头痛的时间序列数据:时区。
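The same kinds of operations are available with the standard library alone. This sketch sticks to `datetime` (no `dateutil` or `pytz`): formatting with `strftime` codes, arithmetic with `timedelta`, and parsing a string whose format is known in advance with `strptime`. Note that `%A` and `%B` use the current locale, which is English under Python's default C locale.

```python
from datetime import datetime, timedelta

date = datetime(year=2015, month=7, day=4)

print(date.strftime('%A'))              # Saturday
print(date.strftime('%Y-%m-%d %H:%M'))  # 2015-07-04 00:00

# Arithmetic with timedelta objects
later = date + timedelta(days=12, hours=6)
print(later)                            # 2015-07-16 06:00:00
print((later - date).total_seconds())   # 1058400.0

# Parsing a string in a known, fixed format (dateutil handles free-form text)
parsed = datetime.strptime("4 July 2015", "%d %B %Y")
print(parsed == date)                   # True
```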
The power of datetime and dateutil lies in their flexibility and easy syntax: you can use these objects and their built-in methods to easily perform nearly any operation you might be interested in. Where they break down is when you wish to work with large arrays of dates and times: just as lists of Python numerical variables are suboptimal compared to NumPy-style typed numerical arrays, lists of Python datetime objects are suboptimal compared to typed arrays of encoded dates.
datetime 和 dateutil 的强大在于它们灵活而易懂的语法:你可以使用这些对象及其内建的方法完成几乎所有你感兴趣的时间操作。但是当处理大量日期时间组成的数组时,它们就无法胜任了:就像Python的数值列表比不上NumPy的类型数组一样,由Python日期时间对象组成的列表也比不上编码后的日期时间类型数组。
Typed arrays of times: NumPy's datetime64
时间的类型数组:NumPy 的 datetime64
The weaknesses of Python's datetime format inspired the NumPy team to add a set of native time series data types to NumPy. The datetime64 dtype encodes dates as 64-bit integers, and thus allows arrays of dates to be represented very compactly. The datetime64 requires a very specific input format:
Python日期时间对象的弱点促使NumPy的开发团队在NumPy中加入了原生的时间序列数据类型。 datetime64 数据类型将日期时间编码成了一个64位的整数,因此NumPy存储日期时间数组的格式非常紧凑。 datetime64 规定了非常明确的输入格式:
In [4]: import numpy as np
date = np.array('2015-07-04', dtype=np.datetime64)
date
Out[4]: array('2015-07-04', dtype='datetime64[D]')
Once we have this date formatted, however, we can quickly do vectorized operations on it:
然后我们就能⽴刻在这个⽇期数组之上应⽤向量化操作:
In [5]: date + np.arange(12)
Out[5]: array(['2015-07-04', '2015-07-05', '2015-07-06', '2015-07-07',
'2015-07-08', '2015-07-09', '2015-07-10', '2015-07-11',
'2015-07-12', '2015-07-13', '2015-07-14', '2015-07-15'],
dtype='datetime64[D]')
Because of the uniform type in NumPy datetime64 arrays, this type of operation can be accomplished much more
quickly than if we were working directly with Python's datetime objects, especially as arrays get large (we introduced
this type of vectorization in Computation on NumPy Arrays: Universal Functions).
因为NumPy数组中所有元素都具有统⼀的 datetime64 类型,上⾯的向量化操作将会⽐我们使⽤Python的 datetime 对象⾼效许多,特
别是当数组变得很⼤的情况下(我们在使⽤Numpy计算:通⽤函数中详细介绍过)。
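Vectorized arithmetic also works in the other direction: subtracting `datetime64` values yields `timedelta64` values, and the array's unit determines what a plain integer offset means. A small sketch building on the array above:

```python
import numpy as np

date = np.array('2015-07-04', dtype=np.datetime64)  # unit inferred as days: datetime64[D]
dates = date + np.arange(12)                        # 12 consecutive days

# Differences between datetime64 values are timedelta64 values
gaps = dates - dates[0]
print(gaps.dtype)     # timedelta64[D]
print(gaps[-1])       # 11 days

# Integers added to a datetime64[D] array count days; use timedelta64 for other units
half_days = date + np.arange(3) * np.timedelta64(12, 'h')
print(half_days.dtype)  # datetime64[h]
```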
One detail of the datetime64 and timedelta64 objects is that they are built on a fundamental time unit. Because the datetime64 object is limited to 64-bit precision, the range of encodable times is $2^{64}$ times this fundamental unit. In other words, datetime64 imposes a trade-off between time resolution and maximum time span.
关于 datetime64 和 timedelta64 对象还有一个细节,就是它们都是在基本时间单位之上构建的。因为 datetime64 被限制在64位精度上,因此它可被编码的时间范围就是 $2^{64}$ 乘以相应的时间单位。换言之, datetime64 需要在时间精度和最大时间跨度之间进行取舍。
For example, if you want a time resolution of one nanosecond, you only have enough information to encode a range of $2^{64}$ nanoseconds, or just under 600 years. NumPy will infer the desired unit from the input; for example, here is a day-based datetime:
例如,如果时间单位是纳秒, datetime64 类型能够编码的时间范围就是 $2^{64}$ 纳秒,不到600年。NumPy可以自动从输入推断需要的时间精度(单位);如下面是以天为单位:
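The "just under 600 years" figure follows from simple arithmetic: $2^{64}$ nanosecond ticks, split around the 1970 epoch, gives roughly ±292 years. A quick check, assuming a 365.25-day year (an approximation):

```python
import numpy as np

ns_per_year = 365.25 * 24 * 3600 * 1e9    # approximate nanoseconds per year

total_span_years = 2**64 / ns_per_year     # full range of a 64-bit counter
half_span_years = 2**63 / ns_per_year      # range on each side of 1970-01-01
print(round(total_span_years, 1))          # 584.5
print(round(half_span_years, 1))           # 292.3

# The largest int64 interpreted as nanoseconds lands in the year 2262
max_ns = np.datetime64(np.iinfo(np.int64).max, 'ns')
print(max_ns)                              # 2262-04-11T23:47:16.854775807
```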
In [6]: np.datetime64('2015-07-04')
Out[6]: numpy.datetime64('2015-07-04')
Here is a minute-based datetime:
下⾯是分钟为单位:
In [7]: np.datetime64('2015-07-04 12:00')
Out[7]: numpy.datetime64('2015-07-04T12:00')
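Notice that these datetime64 values carry no time zone: since NumPy 1.11 they are time-zone naive and stored exactly as written (older NumPy versions interpreted such input in the local time zone of the computer executing the code). You can force any desired fundamental unit using one of many format codes; for example, here we'll force a nanosecond-based time:
还需要注意的是, datetime64 的值不带时区信息:从NumPy 1.11开始,它们不关联任何时区,按书写的值原样保存(更早的NumPy版本会按照运行代码的计算机的本地时区来解释输入)。你可以通过额外指定时间单位参数来设置你需要的精度;例如,下面使用的是纳秒单位: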
In [8]: np.datetime64('2015-07-04 12:59:59.50', 'ns')
Out[8]: numpy.datetime64('2015-07-04T12:59:59.500000000')
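The unit NumPy picks is visible in each value's dtype. This sketch checks the inferred units for the three inputs used above and shows that casting to a coarser unit truncates the finer part of the value:

```python
import numpy as np

day = np.datetime64('2015-07-04')
minute = np.datetime64('2015-07-04 12:00')
nano = np.datetime64('2015-07-04 12:59:59.50', 'ns')

print(day.dtype, minute.dtype, nano.dtype)
# datetime64[D] datetime64[m] datetime64[ns]

# Casting to a coarser unit drops seconds and fractions of a second
print(nano.astype('datetime64[m]'))   # 2015-07-04T12:59
```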
The following table, drawn from the NumPy datetime64 documentation, lists the available format codes along with the
relative and absolute timespans that they can encode:
下⾯这张表,来⾃NumPy datetime64类型在线⽂档,列出了可⽤的时间单位代码以及其相应的时间范围限制:
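Code | Meaning     | Time span (relative) | Time span (absolute)
---- | ----------- | -------------------- | ----------------------
Y    | Year        | ± 9.2e18 years       | [9.2e18 BC, 9.2e18 AD]
M    | Month       | ± 7.6e17 years       | [7.6e17 BC, 7.6e17 AD]
W    | Week        | ± 1.7e17 years       | [1.7e17 BC, 1.7e17 AD]
D    | Day         | ± 2.5e16 years       | [2.5e16 BC, 2.5e16 AD]
h    | Hour        | ± 1.0e15 years       | [1.0e15 BC, 1.0e15 AD]
m    | Minute      | ± 1.7e13 years       | [1.7e13 BC, 1.7e13 AD]
s    | Second      | ± 2.9e11 years       | [2.9e11 BC, 2.9e11 AD]
ms   | Millisecond | ± 2.9e8 years        | [2.9e8 BC, 2.9e8 AD]
us   | Microsecond | ± 2.9e5 years        | [290301 BC, 294241 AD]
ns   | Nanosecond  | ± 292 years          | [1678 AD, 2262 AD]
ps   | Picosecond  | ± 106 days           | [1969 AD, 1970 AD]
fs   | Femtosecond | ± 2.6 hours          | [1969 AD, 1970 AD]
as   | Attosecond  | ± 9.2 seconds        | [1969 AD, 1970 AD]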
For the types of data we see in the real world, a useful default is datetime64[ns] , as it can encode a useful range of
modern dates with a suitably fine precision.
对于我们⽬前真实世界的数据来说,⼀个合适的默认值可以是 datetime64[ns] ,因为它既能包含现代的时间范围,也能提供相当⾼的
时间精度。
Finally, we will note that while the datetime64 data type addresses some of the deficiencies of the built-in Python
datetime type, it lacks many of the convenient methods and functions provided by datetime and especially
dateutil . More information can be found in NumPy's datetime64 documentation.
最后,还要提醒的是,虽然 datetime64 数据类型解决了Python內建 datetime 类型的低效问题,但是它却缺少很多 datetime 特别
是 dateutil 对象提供的很⽅便的⽅法。你可以在NumPy的datetime64在线⽂档中查阅更多相关内容。
Dates and times in pandas: best of both worlds
Pandas中的日期和时间:兼得所长
Pandas builds upon all the tools just discussed to provide a Timestamp object, which combines the ease-of-use of
datetime and dateutil with the efficient storage and vectorized interface of numpy.datetime64 . From a group
of these Timestamp objects, Pandas can construct a DatetimeIndex that can be used to index data in a Series
or DataFrame ; we'll see many examples of this below.
Pandas在刚才介绍的那些工具的基础上构建了 Timestamp 对象,既包含了 datetime 和 dateutil 的简单易用,又吸收了 numpy.datetime64 的高效存储和向量化操作优点。将这些 Timestamp 对象组合起来之后,Pandas就能构建一个 DatetimeIndex ,用来在 Series 或 DataFrame 当中对数据进行索引;我们下面会看到很多有关的例子。
For example, we can use Pandas tools to repeat the demonstration from above. We can parse a flexibly formatted string
date, and use format codes to output the day of the week:
例如,我们使⽤Pandas⼯具可以重复上⾯的例⼦。我们可以将⼀个灵活表⽰时间的字符串解析成⽇期时间对象,然后⽤时间格式化代码进
⾏格式化输出星期⼏:
In [9]: import pandas as pd
date = pd.to_datetime("4th of July, 2015")
date
Out[9]: Timestamp('2015-07-04 00:00:00')
In [10]: date.strftime('%A')
Out[10]: 'Saturday'
Additionally, we can do NumPy-style vectorized operations directly on this same object:
并且,我们可以将NumPy⻛格的向量化操作直接应⽤在同⼀个对象上:
In [11]: date + pd.to_timedelta(np.arange(12), 'D')
Out[11]: DatetimeIndex(['2015-07-04', '2015-07-05', '2015-07-06', '2015-07-07',
'2015-07-08', '2015-07-09', '2015-07-10', '2015-07-11',
'2015-07-12', '2015-07-13', '2015-07-14', '2015-07-15'],
dtype='datetime64[ns]', freq=None)
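A `Timestamp` keeps the `datetime`-style conveniences, and the `DatetimeIndex` produced by the vectorized step exposes the same fields element-wise. A short sketch, using an ISO date string to avoid any format-inference differences between pandas versions:

```python
import numpy as np
import pandas as pd

date = pd.to_datetime("2015-07-04")
days = date + pd.to_timedelta(np.arange(12), unit='D')   # DatetimeIndex of 12 days

print(date.day_name())                 # Saturday
print(days.day_name()[:3].tolist())    # ['Saturday', 'Sunday', 'Monday']

# A Timestamp converts back to a standard-library datetime
print(date.to_pydatetime())            # 2015-07-04 00:00:00

# Boolean mask over the whole index at once: which days are weekends?
weekend = days[days.dayofweek >= 5]    # Monday=0 ... Saturday=5, Sunday=6
print(weekend.strftime('%Y-%m-%d').tolist())
```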
In the next section, we will take a closer look at manipulating time series data with the tools provided by Pandas.
下⾯,我们将详细介绍使⽤Pandas提供的⼯具对时间序列进⾏操作的⽅法。
Pandas Time Series: Indexing by Time
Pandas时间序列:使用时间索引
Where the Pandas time series tools really become useful is when you begin to index data by timestamps. For example,
we can construct a Series object that has time indexed data:
对于Pandas时间序列⼯具来说,使⽤时间戳来索引数据,才是真正吸引⼈的地⽅。例如,我们可以创建⼀个 Series 对象具有时间索引
标签:
In [12]: index = pd.DatetimeIndex(['2014-07-04', '2014-08-04',
'2015-07-04', '2015-08-04'])
data = pd.Series([0, 1, 2, 3], index=index)
data
Out[12]: 2014-07-04    0
2014-08-04    1
2015-07-04    2
2015-08-04    3
dtype: int64
Now that we have this data in a Series , we can make use of any of the Series indexing patterns we discussed in
previous sections, passing values that can be coerced into dates:
这样我们就有了⼀个 Series 数据,我们可以将任何 Series 索引的⽅法应⽤到这个对象上,我们可以传⼊参数值,Pandas会⾃动转换
为⽇期时间进⾏操作:
In [13]: data['2014-07-04':'2015-07-04']
Out[13]: 2014-07-04    0
2014-08-04    1
2015-07-04    2
dtype: int64
There are additional special date-only indexing operations, such as passing a year to obtain a slice of all data from that
year:
还有很多有关⽇期的索引⽅式,如下⾯将年作为参数传⼊,会得到⼀个全年数据的切⽚:
In [14]: data['2015']
Out[14]: 2015-07-04    2
2015-08-04    3
dtype: int64
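The same lookups can be written with the explicit `.loc` indexer, which makes the label-based intent clear; partial strings such as a year or a year-month select every timestamp inside that span, and slice endpoints need not be present in the index. A sketch on the four-point Series from above:

```python
import pandas as pd

index = pd.DatetimeIndex(['2014-07-04', '2014-08-04',
                          '2015-07-04', '2015-08-04'])
data = pd.Series([0, 1, 2, 3], index=index)

print(data.loc['2015'].tolist())                      # [2, 3]
print(data.loc['2014-07-04':'2015-07-04'].tolist())   # [0, 1, 2]
print(data.loc['2014-08'].tolist())                   # [1]

# Slicing endpoints that are not in the index still work on a sorted index
print(data.loc['2014-07-15':'2015-07-31'].tolist())   # [1, 2]
```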
Later, we will see additional examples of the convenience of dates-as-indices. But first, a closer look at the available time
series data structures.
后⾯我们会看到更多使⽤⽇期时间作为索引值的例⼦。⾸先来详细看看时间序列数据的结构。
Pandas Time Series Data Structures
Pandas时间序列数据结构
This section will introduce the fundamental Pandas data structures for working with time series data:
For time stamps, Pandas provides the Timestamp type. As mentioned before, it is essentially a replacement for
Python's native datetime , but is based on the more efficient numpy.datetime64 data type. The associated
Index structure is DatetimeIndex .
For time Periods, Pandas provides the Period type. This encodes a fixed-frequency interval based on
numpy.datetime64 . The associated index structure is PeriodIndex .
For time deltas or durations, Pandas provides the Timedelta type. Timedelta is a more efficient replacement
for Python's native datetime.timedelta type, and is based on numpy.timedelta64 . The associated index
structure is TimedeltaIndex .
这部分内容会介绍Pandas在处理时间序列数据时候使⽤的基本数据结构:
对于时间戳,Pandas提供了 Timestamp 类型。正如上⾯所述,它可以作为Python原⽣ datetime 类型的替代,但是它是构建在
numpy.datetime64 数据类型之上的。对应的索引结构是 DatetimeIndex 。
对于时间周期,Pandas提供了 Period 类型。它是在 numpy.datetime64 的基础上编码了⼀个固定周期间隔的时间。对应的索引
结构是 PeriodIndex 。
对于时间差或持续时间,Pandas提供了 Timedelta 类型。构建于 numpy.timedelta64 之上,是Python原⽣
datetime.timedelta 类型的⾼性能替代。对应的索引结构是 TimedeltaIndex 。
The most fundamental of these date/time objects are the Timestamp and DatetimeIndex objects. While these
class objects can be invoked directly, it is more common to use the pd.to_datetime() function, which can parse a
wide variety of formats. Passing a single date to pd.to_datetime() yields a Timestamp ; passing a series of dates
by default yields a DatetimeIndex :
上述这些⽇期时间对象中最基础的是 Timestamp 和 DatetimeIndex 对象。虽然这些对象可以直接被创建,但是更通⽤的做法是使⽤
pd.to_datetime() 函数,该函数可以将多种格式的字符串解析成⽇期时间。将⼀个⽇期时间传递给 pd.to_datetime() 会得到⼀
个 Timestamp 对象;将⼀系列的⽇期时间传递过去会得到⼀个 DatetimeIndex 对象:
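In [15]: dates = pd.to_datetime([datetime(2015, 7, 3), '4th of July, 2015',
'2015-Jul-6', '07-07-2015', '20150708'])
# pandas >= 2.0 infers a single format from the input and applies it to every string;
# there, add format='mixed' to parse a list of mixed formats like this one
dates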
Out[15]: DatetimeIndex(['2015-07-03', '2015-07-04', '2015-07-06', '2015-07-07',
'2015-07-08'],
dtype='datetime64[ns]', freq=None)
Any DatetimeIndex can be converted to a PeriodIndex with the to_period() function with the addition of a
frequency code; here we'll use 'D' to indicate daily frequency:
任何 DatetimeIndex 对象都能使⽤ to_period() 函数转换成 PeriodIndex 对象,不过需要额外指定⼀个频率的参数码;下⾯我们
使⽤ 'D' 来指定频率为天:
In [16]: dates.to_period('D')
Out[16]: PeriodIndex(['2015-07-03', '2015-07-04', '2015-07-06', '2015-07-07',
'2015-07-08'],
dtype='period[D]', freq='D')
A TimedeltaIndex is created, for example, when a date is subtracted from another:
TimedeltaIndex 对象可以通过日期时间相减来创建,例如:
In [17]: dates - dates[0]
Out[17]: TimedeltaIndex(['0 days', '1 days', '3 days', '4 days', '5 days'], dtype='timedelta64[ns]', freq=None)
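In pandas 2.0 and later, `pd.to_datetime` infers one format and applies it to the whole list, so a mixed-format list like the one in `In [15]` needs `format='mixed'`. The sketch below sidesteps that by using consistent ISO strings, then repeats the `to_period()` and subtraction steps:

```python
import pandas as pd

dates = pd.to_datetime(['2015-07-03', '2015-07-04', '2015-07-06',
                        '2015-07-07', '2015-07-08'])

periods = dates.to_period('D')    # PeriodIndex with daily frequency
deltas = dates - dates[0]         # TimedeltaIndex relative to the first date

print(periods[0])                 # 2015-07-03
print(deltas.days.tolist())       # [0, 1, 3, 4, 5]

# Converting back: the start of each daily period is the original timestamp
print((periods.to_timestamp() == dates).all())   # True
```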
Regular sequences: pd.date_range()
规则序列: pd.date_range()
To make the creation of regular date sequences more convenient, Pandas offers a few functions for this purpose:
pd.date_range() for timestamps, pd.period_range() for periods, and pd.timedelta_range() for time
deltas. We've seen that Python's range() and NumPy's np.arange() turn a startpoint, endpoint, and optional
stepsize into a sequence. Similarly, pd.date_range() accepts a start date, an end date, and an optional frequency
code to create a regular sequence of dates. By default, the frequency is one day:
Pandas提供了三个函数来创建规则的日期时间序列: pd.date_range() 用来创建时间戳的序列, pd.period_range() 用来创建周期的序列, pd.timedelta_range() 用来创建时间差的序列。我们都已经学习过Python的 range() 和NumPy的 np.arange() 了,它们接受开始点、结束点和可选的步长参数来创建序列。同样, pd.date_range() 接受开始日期、结束日期和可选的频率码来创建日期时间的规则序列。默认频率为一天:
In [18]: pd.date_range('2015-07-03', '2015-07-10')
Out[18]: DatetimeIndex(['2015-07-03', '2015-07-04', '2015-07-05', '2015-07-06',
'2015-07-07', '2015-07-08', '2015-07-09', '2015-07-10'],
dtype='datetime64[ns]', freq='D')
Alternatively, the date range can be specified not with a start and endpoint, but with a startpoint and a number of periods:
⽽且,⽇期时间的范围不仅能通过结束⽇期时间指定,还能通过开始⽇期时间和⼀个持续值来指定:
In [19]: pd.date_range('2015-07-03', periods=8)
Out[19]: DatetimeIndex(['2015-07-03', '2015-07-04', '2015-07-05', '2015-07-06',
'2015-07-07', '2015-07-08', '2015-07-09', '2015-07-10'],
dtype='datetime64[ns]', freq='D')
The spacing can be modified by altering the freq argument, which defaults to D . For example, here we will construct
a range of hourly timestamps:
⽇期时间的间隔可以通过指定 freq 频率参数来修改,否则默认为天 D 。例如,下⾯创建⼀段以⼩时为间隔单位的时间范围:
In [20]: pd.date_range('2015-07-03', periods=8, freq='H')
Out[20]: DatetimeIndex(['2015-07-03 00:00:00', '2015-07-03 01:00:00',
'2015-07-03 02:00:00', '2015-07-03 03:00:00',
'2015-07-03 04:00:00', '2015-07-03 05:00:00',
'2015-07-03 06:00:00', '2015-07-03 07:00:00'],
dtype='datetime64[ns]', freq='H')
To create regular sequences of Period or Timedelta values, the very similar pd.period_range() and
pd.timedelta_range() functions are useful. Here are some monthly periods:
要创建 Period 或 Timedelta 对象,可以类似的调⽤ pd.period_range() 和 pd.timedelta_range() 函数。下⾯是以⽉为单位
的时间周期序列:
In [21]: pd.period_range('2015-07', periods=8, freq='M')
Out[21]: PeriodIndex(['2015-07', '2015-08', '2015-09', '2015-10', '2015-11', '2015-12',
'2016-01', '2016-02'],
dtype='period[M]', freq='M')
And a sequence of durations increasing by an hour:
下⾯是以⼩时为单位的持续时间序列:
In [22]: pd.timedelta_range(0, periods=10, freq='H')
Out[22]: TimedeltaIndex(['00:00:00', '01:00:00', '02:00:00', '03:00:00', '04:00:00',
'05:00:00', '06:00:00', '07:00:00', '08:00:00', '09:00:00'],
dtype='timedelta64[ns]', freq='H')
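The three range functions share the same start/end/periods/freq pattern. A sketch that checks their outputs; pandas 2.2 renamed several aliases (for example, hours became `'h'` instead of `'H'`), so this example sticks to `'D'`, `'M'` for periods, and `'60min'`, which should be accepted by both older and newer versions:

```python
import pandas as pd

# End point given explicitly vs. number of periods: the same eight days
by_end = pd.date_range('2015-07-03', '2015-07-10')
by_count = pd.date_range('2015-07-03', periods=8)
print(by_end.equals(by_count))      # True

months = pd.period_range('2015-07', periods=8, freq='M')
print(months[-1])                   # 2016-02

hours = pd.timedelta_range(0, periods=10, freq='60min')
print(hours[-1])                    # 0 days 09:00:00
```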
All of these require an understanding of Pandas frequency codes, which we'll summarize in the next section.
上述函数都需要我们理解Pandas的频率编码,我们⻢上会介绍它。
Frequencies and Offsets
频率和偏移值
Fundamental to these Pandas time series tools is the concept of a frequency or date offset. Just as we saw the D (day)
and H (hour) codes above, we can use such codes to specify any desired frequency spacing. The following table
summarizes the main codes available:
要使⽤Pandas时间序列⼯具,我们需要理解频率和时间偏移值的概念。就像前⾯我们看到的 D 代表天和 H 代表⼩时⼀样,我们可以使⽤
这类符号码指定需要的频率间隔。下表总结了主要的频率码:
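Code | Description  | Code | Description
---- | ------------ | ---- | --------------------
D    | Calendar day | B    | Business day
W    | Weekly       |      |
M    | Month end    | BM   | Business month end
Q    | Quarter end  | BQ   | Business quarter end
A    | Year end     | BA   | Business year end
H    | Hours        | BH   | Business hours
T    | Minutes      |      |
S    | Seconds      |      |
L    | Milliseconds |      |
U    | Microseconds |      |
N    | Nanoseconds  |      |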
The monthly, quarterly, and annual frequencies are all marked at the end of the specified period. By adding an S suffix
to any of these, they instead will be marked at the beginning:
上⾯的⽉、季度和年都代表着该时间周期的结束时间。如果在这些码后⾯加上 S 后缀,则代表这些时间周期的起始时间:
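Code | Description   | Code | Description
---- | ------------- | ---- | ----------------------
MS   | Month start   | BMS  | Business month start
QS   | Quarter start | BQS  | Business quarter start
AS   | Year start    | BAS  | Business year start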
Additionally, you can change the month used to mark any quarterly or annual code by adding a three-letter month code
as a suffix:
Q-JAN , BQ-FEB , QS-MAR , BQS-APR , etc.
A-JAN , BA-FEB , AS-MAR , BAS-APR , etc.
并且你可以通过在季度或者年的符号码后⾯添加三个字⺟的⽉份缩写来指定周期进⾏分隔的⽉份:
Q-JAN 、 BQ-FEB 、 QS-MAR 、 BQS-APR 等
A-JAN 、 BA-FEB 、 AS-MAR 、 BAS-APR 等
In the same way, the split-point of the weekly frequency can be modified by adding a three-letter weekday code:
W-SUN , W-MON , W-TUE , W-WED , etc.
同样,每周的分隔⽇也可以通过在周符号码后⾯添加三个字⺟的星期⼏缩写来指定:
W-SUN 、 W-MON 、 W-TUE 、 W-WED 等
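Anchored codes change which calendar points are generated. A sketch using two codes from the lists above whose spelling is unchanged across pandas versions (`'W-SUN'` and `'MS'`); July 4, 2015 was a Saturday, so the first weekly point after July 1 falls on Sunday, July 5:

```python
import pandas as pd

# Weekly points anchored on Sundays
sundays = pd.date_range('2015-07-01', periods=3, freq='W-SUN')
print(sundays.strftime('%a %Y-%m-%d').tolist())
# ['Sun 2015-07-05', 'Sun 2015-07-12', 'Sun 2015-07-19']

# Month-start points: a start date that is not a month start rolls forward
starts = pd.date_range('2015-07-15', periods=3, freq='MS')
print(starts.strftime('%Y-%m-%d').tolist())
# ['2015-08-01', '2015-09-01', '2015-10-01']
```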
On top of this, codes can be combined with numbers to specify other frequencies. For example, for a frequency of 2
hours 30 minutes, we can combine the hour ( H ) and minute ( T ) codes as follows:
在此之上,符号码还可以进⾏组合⽤来代表其他的频率。例如要表⽰2⼩时30分钟的频率,我们可以通过将⼩时( H )和分钟( T )的符
号码进⾏组合得到:
In [23]: pd.timedelta_range(0, periods=9, freq="2H30T")
Out[23]: TimedeltaIndex(['00:00:00', '02:30:00', '05:00:00', '07:30:00', '10:00:00',
'12:30:00', '15:00:00', '17:30:00', '20:00:00'],
dtype='timedelta64[ns]', freq='150T')
All of these short codes refer to specific instances of Pandas time series offsets, which can be found in the
pd.tseries.offsets module. For example, we can create a business day offset directly as follows:
上述的这些短的符号码实际上是Pandas时间序列偏移值的对象实例的别名,你可以在 pd.tseries.offsets 模块中找到这些偏移值实
例。例如,我们也可以通过⼀个偏移值对象实例来创建时间序列:
In [24]: from pandas.tseries.offsets import BDay
pd.date_range('2015-07-01', periods=5, freq=BDay())
Out[24]: DatetimeIndex(['2015-07-01', '2015-07-02', '2015-07-03', '2015-07-06',
'2015-07-07'],
dtype='datetime64[ns]', freq='B')
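Offset objects can also be added to individual timestamps, and a combined frequency is just another way of writing a fixed duration. A sketch under the same setting as above: `BDay` only skips weekends and knows nothing about holidays, which is why Friday, July 3, 2015 appears in the output above. The `'H'`/`'T'` spellings were deprecated in pandas 2.2 in favor of `'h'`/`'min'`, so this sketch writes 2 hours 30 minutes as `'150min'`:

```python
import pandas as pd
from pandas.tseries.offsets import BDay

# Adding one business day to a Friday lands on the following Monday
friday = pd.Timestamp('2015-07-03')
print(friday + BDay())            # 2015-07-06 00:00:00

# "2 hours 30 minutes" written as minutes
steps = pd.timedelta_range(0, periods=3, freq='150min')
print(steps.tolist())
print(steps[1] == pd.Timedelta(hours=2, minutes=30))   # True
```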
For more discussion of the use of frequencies and offsets, see the "DateOffset" section of the Pandas documentation.
更多有关频率和偏移值的讨论,请参阅Pandas在线⽂档⽇期时间偏移值章节。
Resampling, Shifting, and Windowing
重新取样、移动和窗⼝
The ability to use dates and times as indices to intuitively organize and access data is an important piece of the Pandas
time series tools. The benefits of indexed data in general (automatic alignment during operations, intuitive data slicing
and access, etc.) still apply, and Pandas provides several additional time series-specific operations.
使⽤⽇期和时间作为索引来直观的组织和访问数据的能⼒,是Pandas时间序列⼯具的重要功能。前⾯介绍过的索引的那些通⽤优点(⾃动
对⻬,直观的数据切⽚和访问等)依然有效,⽽且Pandas提供了许多额外的时间序列相关操作。
We will take a look at a few of those here, using some stock price data as an example. Because Pandas was developed
largely in a finance context, it includes some very specific tools for financial data. For example, the accompanying
pandas-datareader package (installable via conda install pandas-datareader ), knows how to import
financial data from a number of available sources, including Yahoo finance, Google Finance, and others. Here we will
load Google's closing price history:
我们会在这⾥介绍其中的⼀些,使⽤股票价格数据作为例⼦。因为Pandas是在⾦融背景基础上发展⽽来的,因此它具有⼀些特别的⾦融数
据相关⼯具。例如, pandas-datareader 包(可以通过 conda install pandas-datareader 进⾏安装)可以被⽤来从许多可⽤
的数据源导⼊⾦融数据,包括Yahoo⾦融,Google⾦融和其他。下⾯我们将载⼊Google的收市价历史数据:
Translator's note: newer versions of pandas-datareader no longer support the google data source, so the translator used the yahoo data source instead.
In [28]: from pandas_datareader import data
goog = data.DataReader('GOOG', start='2004', end='2016',
data_source='yahoo')
goog.head()
Out[28]:
                 High        Low       Open      Close      Volume  Adj Close
Date
2004-08-19  51.835709  47.800831  49.813286  49.982655  44871300.0  49.982655
2004-08-20  54.336334  50.062355  50.316402  53.952770  22942800.0  53.952770
2004-08-23  56.528118  54.321388  55.168217  54.495735  18342800.0  54.495735
2004-08-24  55.591629  51.591621  55.412300  52.239193  15319700.0  52.239193
2004-08-25  53.798351  51.746044  52.284027  52.802086   9232100.0  52.802086
For simplicity, we'll use just the closing price:
为简单起⻅,我们仅使⽤收市价:
In [29]: goog = goog['Close']
We can visualize this using the plot() method, after the normal Matplotlib setup boilerplate (see Chapter 4):
我们可以使⽤ plot() ⽅法来做出图表,当然之前要先完成Matplotlib的相关初始化⼯作(参⻅第四章):
In [30]: %matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set()
In [31]: goog.plot();
Resampling and converting frequencies
重新采样和改变频率
One common need for time series data is resampling at a higher or lower frequency. This can be done using the
resample() method, or the much simpler asfreq() method. The primary difference between the two is that
resample() is fundamentally a data aggregation, while asfreq() is fundamentally a data selection.
对于时间序列数据来说有⼀个很普遍的需求是对数据根据更⾼或更低的频率进⾏重新取样。这可以通过 resample() ⽅法或更简单的
asfreq() ⽅法来实现。两者的主要区别在于 resample() 主要进⾏数据聚合操作,⽽ asfreq() ⽅法主要进⾏数据选择操作。
Taking a look at the Google closing price, let's compare what the two return when we down-sample the data. Here we will
resample the data at the end of business year:
观察⼀下⾕歌的收市价,让我们来⽐较⼀下使⽤两者对数据进⾏更低频率来采样的情况。下⾯我们对数据进⾏每个⼯作⽇年度进⾏重新取
样:
In [32]: goog.plot(alpha=0.5, style='-')
goog.resample('BA').mean().plot(style=':')
goog.asfreq('BA').plot(style='--');
plt.legend(['input', 'resample', 'asfreq'],
loc='upper left');
Notice the difference: at each point, resample reports the average of the previous year, while asfreq reports the
value at the end of the year.
注意这⾥的区别:在每个点, resample 返回了这⼀个年度的平均值,⽽ asfreq 返回了年末的收市值。
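The aggregation-versus-selection distinction is easiest to see on synthetic data. This sketch uses a made-up daily series (values 0 to 59 starting January 1, 2015) and the month-start code `'MS'` instead of the business-year code used above:

```python
import pandas as pd

daily = pd.Series(range(60),
                  index=pd.date_range('2015-01-01', periods=60, freq='D'),
                  dtype=float)

# resample: aggregate every value inside each month
monthly_mean = daily.resample('MS').mean()

# asfreq: pick the single value that sits on each month-start date
monthly_pick = daily.asfreq('MS')

print(monthly_mean.tolist())   # [15.0, 44.5, 59.0]
print(monthly_pick.tolist())   # [0.0, 31.0, 59.0]
```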
For up-sampling, resample() and asfreq() are largely equivalent, though resample() has many more options available. In this case, the default for both methods is to leave the up-sampled points empty, that is, filled with NA values. Just as with the fillna() method discussed previously, asfreq() accepts a method argument to specify how values are imputed. Here, we will resample the business day data at a daily frequency (i.e., including weekends):
对于更高频率的采样来说, resample() 和 asfreq() 方法大体上是相同的,虽然 resample() 有着更多的参数。在这个例子中,两者默认都将新增的采样点留空,即填充为NA值。就像之前介绍过的 fillna() 方法那样, asfreq() 方法接受一个 method 参数来指定以哪种方式填充值。下面,我们将原本数据的工作日频率扩张为自然日频率(即包括周末):
In [33]: fig, ax = plt.subplots(2, sharex=True)
data = goog.iloc[:10]
data.asfreq('D').plot(ax=ax[0], marker='o')
data.asfreq('D', method='bfill').plot(ax=ax[1], style='-o')
data.asfreq('D', method='ffill').plot(ax=ax[1], style='--o')
ax[1].legend(["back-fill", "forward-fill"]);
The top panel is the default: non-business days are left as NA values and do not appear on the plot. The bottom panel
shows the differences between two strategies for filling the gaps: forward-filling and backward-filling.
上⾯的⼦图表是默认的:⾮⼯作⽇的数据点被填充为NA值,因此在图中没有显⽰。下⾯的⼦图表展⽰了两种不同填充⽅法的差别:前向填
充和后向填充。
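The same filling strategies can also be spelled as a plain `asfreq()` followed by the `ffill()` or `bfill()` method. A sketch with two made-up business-day prices around a weekend (Friday, July 3 and Monday, July 6, 2015):

```python
import pandas as pd

prices = pd.Series([1.0, 2.0],
                   index=pd.to_datetime(['2015-07-03', '2015-07-06']))

calendar = prices.asfreq('D')        # Saturday and Sunday become NaN
print(calendar.isna().sum())         # 2

print(calendar.ffill().tolist())     # [1.0, 1.0, 1.0, 2.0]
print(calendar.bfill().tolist())     # [1.0, 2.0, 2.0, 2.0]
```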
Time-shifts
时间移动
Another common time series-specific operation is shifting of data in time. Pandas has two closely related methods for computing this: shift() and tshift(). In short, the difference between them is that shift() shifts the data, while tshift() shifts the index. In both cases, the shift is specified in multiples of the frequency.
另一个常见的时间序列相关操作是在时间上移动数据。Pandas有两个很接近的方法来实现时间的移动: shift() 和 tshift() 。简单来说, shift() 移动的是数据,而 tshift() 移动的是时间索引。两个方法使用的移动参数都是当前频率的倍数。
Here we will apply both shift() and tshift() with a shift of 900 days:
下⾯我们使⽤ shift() 和 tshift() ⽅法将数据和时间索引移动900天:
In [34]: fig, ax = plt.subplots(3, sharey=True)
# apply a frequency to the data
goog = goog.asfreq('D', method='pad')
goog.plot(ax=ax[0])               # plot the original data
goog.shift(900).plot(ax=ax[1])    # shift the data by 900 days
goog.tshift(900).plot(ax=ax[2])   # shift the index by 900 days
                                  # (pandas >= 2.0: goog.shift(900, freq='D'))
# legends and annotations
local_max = pd.to_datetime('2007-11-05')
offset = pd.Timedelta(900, 'D')
ax[0].legend(['input'], loc=2)
ax[0].get_xticklabels()[2].set(weight='heavy', color='red')
ax[0].axvline(local_max, alpha=0.3, color='red')
ax[1].legend(['shift(900)'], loc=2)
ax[1].get_xticklabels()[2].set(weight='heavy', color='red')
ax[1].axvline(local_max + offset, alpha=0.3, color='red')
ax[2].legend(['tshift(900)'], loc=2)
ax[2].get_xticklabels()[1].set(weight='heavy', color='red')
ax[2].axvline(local_max + offset, alpha=0.3, color='red');
We see here that shift(900) shifts the data by 900 days, pushing some of it off the end of the graph (and leaving NA
values at the other end), while tshift(900) shifts the index values by 900 days.
上例中,我们看到 shift(900) 将数据向后移动了900天,导致部分数据超出了图表的右侧范围(左侧空出的位置被填充为NA值),而 tshift(900) 将时间索引向后移动了900天。
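`tshift()` was deprecated in pandas 1.1 and removed in 2.0; passing `freq=` to `shift()` moves the index instead of the data, which covers the same use. A sketch on a made-up five-day series:

```python
import pandas as pd

s = pd.Series([10.0, 11.0, 12.0, 13.0, 14.0],
              index=pd.date_range('2015-07-01', periods=5, freq='D'))

data_shift = s.shift(2)              # values move, index stays; first two become NaN
index_shift = s.shift(2, freq='D')   # index moves by 2 days, values untouched

print(data_shift.tolist())                        # [nan, nan, 10.0, 11.0, 12.0]
print(index_shift.index[0].strftime('%Y-%m-%d'))  # 2015-07-03
print(index_shift.tolist() == s.tolist())         # True
```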
A common context for this type of shift is in computing differences over time. For example, we use shifted values to
compute the one-year return on investment for Google stock over the course of the dataset:
这种时间移动的常⻅应⽤场景是计算同⽐时间段的差值。例如,我们可以将数据时间向前移动365天来计算⾕歌股票的年投资回报率:
In [35]: ROI = 100 * (goog.tshift(-365) / goog - 1)  # pandas >= 2.0: goog.shift(-365, freq='D')
ROI.plot()
plt.ylabel('% Return on Investment');
This helps us to see the overall trend in Google stock: thus far, the most profitable times to invest in Google have been
(unsurprisingly, in retrospect) shortly after its IPO, and in the middle of the 2009 recession.
这帮助我们看到⾕歌股票的整体趋势:直到⽬前为⽌,投资⾕歌股票回报最⾼的时期(完全不令⼈惊讶)是IPO之后的短暂时期以及2009
中期经济衰退的时期。
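The return-on-investment formula is easier to check on a toy series. This sketch uses made-up prices and a 2-day horizon in place of 365 days, with `shift(-2, freq='D')` playing the role of `tshift(-2)`: at each date t, the shifted series holds the price from t + 2 days, and division aligns the two series on their index.

```python
import numpy as np
import pandas as pd

prices = pd.Series([100.0, 125.0, 120.0, 150.0],
                   index=pd.date_range('2015-01-01', periods=4, freq='D'))

future = prices.shift(-2, freq='D')    # value at t is the price at t + 2 days
roi = 100 * (future / prices - 1)      # dates without a partner give NaN

# Both remaining entries are a 20% return (up to floating-point rounding):
# 120 / 100 - 1 and 150 / 125 - 1
print(roi.dropna())
```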
Rolling windows
滚动窗⼝
Rolling statistics are a third type of time series-specific operation implemented by Pandas. These can be accomplished via the rolling() method of Series and DataFrame objects, which returns an object similar to what we saw with the groupby operation (see Aggregation and Grouping). This rolling object makes a number of aggregation operations available by default.
For example, here is the one-year centered rolling mean and standard deviation of the Google stock prices:
In [36]: rolling = goog.rolling(365, center=True)  # centered 365-day window (goog has daily frequency after asfreq)
data = pd.DataFrame({'input': goog,
'one-year rolling_mean': rolling.mean(),  # rolling mean Series
'one-year rolling_std': rolling.std()})   # rolling standard deviation Series
ax = data.plot(style=['-', '--', ':'])
ax.lines[0].set_alpha(0.3)
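The center=True argument controls where each window sits relative to the point it is reported at. A minimal sketch on five made-up values: by default the window is trailing (it ends at each point), while center=True places the point in the middle of the window. That is why the centered curve lines up with the input instead of lagging behind it.

```python
import pandas as pd

s = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0])

# Default: each window ends at the current point (trailing window).
trailing = s.rolling(3).mean()

# center=True: each window is centered on the current point.
centered = s.rolling(3, center=True).mean()
```

The trailing result is NaN, NaN, 2, 3, 4, and the centered one is NaN, 2, 3, 4, NaN. In both cases the NaN values mark positions where a full window does not fit.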
As with group-by operations, the aggregate() and apply() methods can be used for custom rolling computations.
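A small sketch of both routes on a toy series. aggregate() (alias agg()) accepts a list of built-in aggregation names. apply() accepts an arbitrary function, here a hypothetical "spread" (max minus min) of each window. Passing raw=True hands the function each window as a NumPy array, which is usually faster than building a Series per window.

```python
import pandas as pd

s = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0])
r = s.rolling(3)

# Several built-in aggregations at once: one output column per name.
summary = r.agg(['mean', 'max'])

# A custom computation per window; raw=True passes a NumPy array.
spread = r.apply(lambda w: w.max() - w.min(), raw=True)
```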
Where to Learn More
This section has provided only a brief summary of some of the most essential features of time series tools provided by
Pandas; for a more complete discussion, you can refer to the "Time Series/Date" section of the Pandas online
documentation.
Another excellent resource is the textbook Python for Data Analysis by Wes McKinney (O'Reilly, 2012). Although it is now
a few years old, it is an invaluable resource on the use of Pandas. In particular, this book emphasizes time series tools in
the context of business and finance, and focuses much more on particular details of business calendars, time zones, and
related topics.
As always, you can also use the IPython help functionality to explore and try further options available to the functions and
methods discussed here. I find this often is the best way to learn a new Python tool.
Example: Visualizing Seattle Bicycle Counts
As a more involved example of working with some time series data, let's take a look at bicycle counts on Seattle's
Fremont Bridge. This data comes from an automated bicycle counter, installed in late 2012, which has inductive sensors
on the east and west sidewalks of the bridge. The hourly bicycle counts can be downloaded from http://data.seattle.gov/;
here is the direct link to the dataset.
As of summer 2016, the CSV can be downloaded as follows:
In [34]: # !curl -o FremontBridge.csv https://data.seattle.gov/api/views/65db-xm6k/rows.csv?accessType=DOWNLOAD
Once this dataset is downloaded, we can use Pandas to read the CSV output into a DataFrame . We will specify that
we want the Date as an index, and we want these dates to be automatically parsed:
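To see what index_col and parse_dates do without downloading the file, here is a sketch that reads two rows in the same column layout from an in-memory string. The rows reuse values from the head() output below. It assumes ISO-style timestamps; the real file's date format may differ, and Pandas may need a format hint for it.

```python
import io
import pandas as pd

# Two rows in the Fremont Bridge column layout (ISO timestamps assumed).
csv_text = """Date,Fremont Bridge Total,Fremont Bridge East Sidewalk,Fremont Bridge West Sidewalk
2012-10-03 00:00:00,13.0,4.0,9.0
2012-10-03 01:00:00,10.0,4.0,6.0
"""

# index_col='Date' makes the Date column the row index;
# parse_dates=True converts that index to a DatetimeIndex.
data = pd.read_csv(io.StringIO(csv_text), index_col='Date', parse_dates=True)
```

Because the index is a DatetimeIndex, the time series tools from earlier in this chapter (resample(), rolling(), shift()) work on it directly.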
In [43]: data = pd.read_csv('data/FremontBridge.csv', index_col='Date', parse_dates=True)
data.head()
Out[43]:
                     Fremont Bridge Total  Fremont Bridge East Sidewalk  Fremont Bridge West Sidewalk
Date
2012-10-03 00:00:00                  13.0                           4.0                           9.0
2012-10-03 01:00:00                  10.0                           4.0                           6.0
2012-10-03 02:00:00                   2.0                           1.0                           1.0
2012-10-03 03:00:00                   5.0                           2.0                           3.0
2012-10-03 04:00:00                   7.0                           6.0                           1.0
For convenience, we'll further process this dataset by shortening the column names and adding a "Total" column:
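Translator's note: the current download of this dataset already includes a Total column, so only the column names need to be shortened. In the code below, the translator has commented out the original lines and shortens the column names in a single line.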
In [44]: # data.columns = ['West', 'East']
# data['Total'] = data.eval('West + East')
data.columns = ['Total', 'East', 'West']
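Since the downloaded file now supplies its own Total column, it is worth checking that it really equals East plus West before relying on it. A sketch of that check, using the five rows printed by data.head() above with the shortened column names:

```python
import pandas as pd

# The five rows shown by data.head() above, with the shortened column names.
idx = pd.date_range('2012-10-03', periods=5, freq=pd.Timedelta(hours=1))
data = pd.DataFrame({'Total': [13.0, 10.0, 2.0, 5.0, 7.0],
                     'East': [4.0, 4.0, 1.0, 2.0, 6.0],
                     'West': [9.0, 6.0, 1.0, 3.0, 1.0]}, index=idx)

# Sanity check: the provided Total column should equal East + West row by row.
consistent = bool((data.eval('East + West') == data['Total']).all())
```

On the full dataset, rows with missing values would compare as False, so you may want to apply dropna() before running this check.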
Now let's take a look at the summary statistics for this data:
In [45]: data.dropna().describe()
Out[45]:
              Total          East          West
count  10771.000000  10771.000000  10771.000000
mean      99.713861     51.416489     48.297373
std      120.397155     63.867062     67.568734
min        0.000000      0.000000      0.000000
25%       15.000000      7.000000      7.000000
50%       57.000000     29.000000     26.000000
75%      134.000000     69.000000     60.000000
max      831.000000    626.000000    593.000000
Visualizing the data
We can gain some insight into the dataset by visualizing it. Let's start by plotting the raw data:
In [46]: %matplotlib inline
import seaborn; seaborn.set()
In [47]: data.plot()
plt.ylabel('Hourly Bicycle Count');
The ~25,000 hourly samples are far too dense for us to make much sense of. We can gain more insight by resampling
the data to a coarser grid. Let's resample by week:
In [48]: weekly = data.resample('W').sum()
weekly.plot(style=[':', '--', '-'])
plt.ylabel('Weekly bicycle count');
This shows us some interesting seasonal trends: as you might expect, people bicycle more in the summer than in the
winter, and even within a particular season the bicycle use varies from week to week (likely dependent on weather; see In
Depth: Linear Regression where we explore this further).
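A note on what resample('W') actually groups. With the 'W' alias (weeks ending on Sunday), each bin covers Monday through Sunday and is labelled with that Sunday's date. Hourly samples from the whole of the Sunday land in the bin labelled with that Sunday. A sketch with two weeks of hypothetical hourly counts of one bicycle per hour:

```python
import numpy as np
import pandas as pd

# Two weeks of hypothetical hourly counts, starting Monday 4 January 2016.
idx = pd.date_range('2016-01-04', periods=14 * 24, freq=pd.Timedelta(hours=1))
hourly = pd.Series(np.ones(len(idx)), index=idx)

# Weekly totals; each bin is labelled by the Sunday that ends it.
weekly = hourly.resample('W').sum()
```

Each week sums to 7 × 24 = 168, and the bins are labelled 2016-01-10 and 2016-01-17.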
Another handy way to aggregate the data is a rolling mean, computed with the rolling() method (the older pd.rolling_mean() function has since been removed from Pandas). Here we'll do a 30-day rolling mean of our data, making sure to center the window:
In [49]: daily = data.resample('D').sum()
daily.rolling(30, center=True).mean().plot(style=[':', '--', '-'])
plt.ylabel('mean daily count');
The jaggedness of the result is due to the hard cutoff of the window. We can get a smoother version of a rolling mean
using a window function, for example a Gaussian window. The following code specifies both the width of the window (we
chose 50 days) and the width of the Gaussian within the window (we chose 10 days):
In [51]: daily.rolling(50, center=True,
win_type='gaussian').mean(std=10).plot(style=[':', '--', '-']);
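What the Gaussian window does can be sketched from scratch in NumPy. This is an illustration of the idea, not Pandas' own code path. Pandas delegates win_type windows to SciPy, so SciPy must be installed for the call above. Each window is weighted by exp(-0.5 · (n/std)²), where n is the distance from the window's center (the same formula as scipy.signal.windows.gaussian). The weighted values are then averaged. The helper names below are hypothetical.

```python
import numpy as np

def gaussian_weights(width, std):
    # n runs from -(width-1)/2 to +(width-1)/2, i.e. distance from the centre.
    n = np.arange(width) - (width - 1) / 2.0
    return np.exp(-0.5 * (n / std) ** 2)

def gaussian_rolling_mean(values, width=50, std=10):
    # Normalise so the weights average the window instead of summing it.
    w = gaussian_weights(width, std)
    w = w / w.sum()
    # The weights are symmetric, so convolve's kernel flip changes nothing;
    # mode='valid' keeps only positions where the whole window fits.
    return np.convolve(values, w, mode='valid')

smoothed = gaussian_rolling_mean(np.full(100, 3.0))
```

Points near the window edges get little weight, which is why the result lacks the jagged steps of the flat 30-day window.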
Digging into the data
While these smoothed data views are useful to get an idea of the general trend in the data, they hide much of the
interesting structure. For example, we might want to look at the average traffic as a function of the time of day. We can do
this using the GroupBy functionality discussed in Aggregation and Grouping:
In [52]: by_time = data.groupby(data.index.time).mean()
hourly_ticks = 4 * 60 * 60 * np.arange(6)  # split the 24 hours of the day into 4-hour segments
by_time.plot(xticks=hourly_ticks, style=[':', '--', '-']);
The hourly traffic is a strongly bimodal distribution, with peaks around 8:00 in the morning and 5:00 in the evening. This is
likely evidence of a strong component of commuter traffic crossing the bridge. This is further evidenced by the differences
between the western sidewalk (generally used going toward downtown Seattle), which peaks more strongly in the
morning, and the eastern sidewalk (generally used going away from downtown Seattle), which peaks more strongly in the
evening.
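Two details of the code above, sketched on three days of hypothetical hourly counts with a spike at 08:00. First, grouping on index.time pools the same clock hour across all days, which gives 24 groups. Second, the tick positions are in seconds since midnight, because that is (as far as we can tell) how Pandas places datetime.time values on the x axis. So 4 * 60 * 60 * np.arange(6) marks 00:00, 04:00, ..., 20:00.

```python
import numpy as np
import pandas as pd

# Three days of hypothetical hourly counts: 10 at 08:00, 2 at every other hour.
idx = pd.date_range('2016-01-04', periods=3 * 24, freq=pd.Timedelta(hours=1))
counts = pd.Series(np.where(idx.hour == 8, 10.0, 2.0), index=idx)

# Average each clock hour across all days.
by_time = counts.groupby(counts.index.time).mean()

# Tick positions in seconds since midnight, one every 4 hours.
hourly_ticks = 4 * 60 * 60 * np.arange(6)
```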
We also might be curious about how things change based on the day of the week. Again, we can do this with a simple
groupby:
In [53]: by_weekday = data.groupby(data.index.dayofweek).mean()
by_weekday.index = ['Mon', 'Tues', 'Wed', 'Thurs', 'Fri', 'Sat', 'Sun']
by_weekday.plot(style=[':', '--', '-']);
This shows a strong distinction between weekday and weekend totals, with around twice as many average riders
crossing the bridge on Monday through Friday than on Saturday and Sunday.
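The day-name labels in In [53] work because dayofweek numbers days from 0 (Monday) to 6 (Sunday), and groupby sorts the group keys. A sketch with two weeks of made-up daily values:

```python
import numpy as np
import pandas as pd

# Two weeks of hypothetical daily counts, starting on Monday 4 January 2016.
idx = pd.date_range('2016-01-04', periods=14, freq='D')
daily = pd.Series(np.arange(14, dtype=float), index=idx)

# dayofweek: 0 = Monday ... 6 = Sunday, so the labels below line up in order.
by_weekday = daily.groupby(daily.index.dayofweek).mean()
by_weekday.index = ['Mon', 'Tues', 'Wed', 'Thurs', 'Fri', 'Sat', 'Sun']
```

Monday averages the values 0 and 7 (giving 3.5), and Sunday averages 6 and 13 (giving 9.5).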
With this in mind, let's do a compound GroupBy and look at the hourly trend on weekdays versus weekends. We'll start
by grouping by both a flag marking the weekend, and the time of day:
In [54]: weekend = np.where(data.index.weekday < 5, 'Weekday', 'Weekend')
by_time = data.groupby([weekend, data.index.time]).mean()
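The compound groupby produces a two-level index: the first level is the Weekday/Weekend label built with np.where, and the second is the time of day. Selecting one label with .loc then returns a 24-row daily profile. A sketch on one hypothetical week where weekday hours carry 4 riders and weekend hours carry 1:

```python
import numpy as np
import pandas as pd

# One hypothetical week of hourly counts (Monday 4 to Sunday 10 January 2016).
idx = pd.date_range('2016-01-04', periods=7 * 24, freq=pd.Timedelta(hours=1))
counts = pd.DataFrame({'Total': np.where(idx.weekday < 5, 4.0, 1.0)}, index=idx)

# Label each row, then group by (label, time of day).
weekend = np.where(counts.index.weekday < 5, 'Weekday', 'Weekend')
by_time = counts.groupby([weekend, counts.index.time]).mean()

weekday_profile = by_time.loc['Weekday']  # one row per hour of the day
weekend_profile = by_time.loc['Weekend']
```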
Now we'll use some of the Matplotlib tools described in Multiple Subplots to plot two panels side by side:
Translator's note: because DataFrame.ix is deprecated (and removed as of Pandas 1.0), the indexer in the code below has been changed to loc.
In [56]: import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 2, figsize=(14, 5))
by_time.loc['Weekday'].plot(ax=ax[0], title='Weekdays',
xticks=hourly_ticks, style=[':', '--', '-'])
by_time.loc['Weekend'].plot(ax=ax[1], title='Weekends',
xticks=hourly_ticks, style=[':', '--', '-']);
The result is very interesting: we see a bimodal commute pattern during the work week, and a unimodal recreational
pattern during the weekends. It would be interesting to dig through this data in more detail, and examine the effect of
weather, temperature, time of year, and other factors on people's commuting patterns; for further discussion, see my blog
post "Is Seattle Really Seeing an Uptick In Cycling?", which uses a subset of this data. We will also revisit this dataset in
the context of modeling in In Depth: Linear Regression.
High-Performance Pandas: eval() and query()
As we've already seen in previous sections, the power of the PyData stack is built upon the ability of NumPy and Pandas
to push basic operations into C via an intuitive syntax: examples are vectorized/broadcasted operations in NumPy, and
grouping-type operations in Pandas. While these abstractions are efficient and effective for many common use cases,
they often rely on the creation of temporary intermediate objects, which can cause undue overhead in computational time
and memory use.
As of version 0.13 (released January 2014), Pandas includes some experimental tools that allow you to directly access
C-speed operations without costly allocation of intermediate arrays. These are the eval() and query() functions,
which rely on the Numexpr package. In this notebook we will walk through their use and give some rules-of-thumb about
when you might think about using them.
Motivating query() and eval(): Compound Expressions
We've seen previously that NumPy and Pandas support fast vectorized operations; for example, when adding the
elements of two arrays:
In [1]: import numpy as np
rng = np.random.RandomState(42)
x = rng.rand(1000000)
y = rng.rand(1000000)
%timeit x + y
2.04 ms ± 62.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
As discussed in Computation on NumPy Arrays: Universal Functions, this is much faster than doing the addition via a
Python loop or comprehension:
In [2]: %timeit np.fromiter((xi + yi for xi, yi in zip(x, y)), dtype=x.dtype, count=len(x))
186 ms ± 14.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
But this abstraction can become less efficient when computing compound expressions. For example, consider the
following expression:
In [3]: mask = (x > 0.5) & (y < 0.5)
Because NumPy evaluates each subexpression, this is roughly equivalent to the following:
In [4]: tmp1 = (x > 0.5)
tmp2 = (y < 0.5)
mask = tmp1 & tmp2
In other words, every intermediate step is explicitly allocated in memory. If the x and y arrays are very large, this can
lead to significant memory and computational overhead. The Numexpr library gives you the ability to compute this type of
compound expression element by element, without the need to allocate full intermediate arrays. The Numexpr
documentation has more details, but for the time being it is sufficient to say that the library accepts a string giving the
NumPy-style expression you'd like to compute:
In [5]: import numexpr
mask_numexpr = numexpr.evaluate('(x > 0.5) & (y < 0.5)')
np.allclose(mask, mask_numexpr)
Out[5]: True
The benefit here is that Numexpr evaluates the expression in a way that does not use full-sized temporary arrays, and
thus can be much more efficient than NumPy, especially for large arrays. The Pandas eval() and query() tools
that we will discuss here are conceptually similar, and depend on the Numexpr package.
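The core idea behind avoiding full-sized temporaries can be illustrated in plain NumPy: evaluate the expression one block at a time, so each intermediate array is only block-sized. This is a rough sketch of the blocking concept only. Numexpr does this inside a compiled virtual machine with its own block size, and the helper name and block size here are hypothetical.

```python
import numpy as np

def blocked_mask(x, y, block=4096):
    # Evaluate (x > 0.5) & (y < 0.5) block by block: the temporaries
    # (x_blk > 0.5, y_blk < 0.5) are at most `block` elements long.
    out = np.empty(len(x), dtype=bool)
    for start in range(0, len(x), block):
        stop = start + block
        out[start:stop] = (x[start:stop] > 0.5) & (y[start:stop] < 0.5)
    return out

rng = np.random.RandomState(42)
x = rng.rand(10_000)
y = rng.rand(10_000)
mask = blocked_mask(x, y)
```

In pure Python the loop overhead usually cancels out the memory savings. The benefit shows up when the blocking runs in compiled code, as it does in Numexpr.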
pandas.eval() for Efficient Operations
The eval() function in Pandas uses string expressions to efficiently compute operations using DataFrames. For
example, consider the following DataFrames:
In [6]: import pandas as pd
nrows, ncols = 100000, 100
rng = np.random.RandomState(42)
df1, df2, df3, df4 = (pd.DataFrame(rng.rand(nrows, ncols))
for i in range(4))
To compute the sum of all four DataFrames using the typical Pandas approach, we can just write the sum:
In [7]: %timeit df1 + df2 + df3 + df4
72.2 ms ± 8.44 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
The same result can be computed via pd.eval by constructing the expression as a string:
In [8]: %timeit pd.eval('df1 + df2 + df3 + df4')
35 ms ± 955 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
The eval() version of this expression takes roughly 50% less time, running about twice as fast in the timings above (and it uses much less memory), while giving the same result:
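Translator's note: the 50% figure follows the original text. On the translator's own laptop, eval() took less than half the time of the standard approach, which amounts to a speedup of more than 100%.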
In [9]: np.allclose(df1 + df2 + df3 + df4,
pd.eval('df1 + df2 + df3 + df4'))
Out[9]: True
Operations supported by pd.eval()
As of Pandas v0.16, pd.eval() supports a wide range of operations. To demonstrate these, we'll use the following
integer DataFrames:
In [10]: df1, df2, df3, df4, df5 = (pd.DataFrame(rng.randint(0, 1000, (100, 3)))
for i in range(5))
Arithmetic operators
pd.eval() supports all arithmetic operators. For example:
In [11]: result1 = -df1 * df2 / (df3 + df4) - df5
result2 = pd.eval('-df1 * df2 / (df3 + df4) - df5')
np.allclose(result1, result2)
Out[11]: True
Comparison operators
pd.eval() supports all comparison operators, including chained expressions:
In [12]: result1 = (df1 < df2) & (df2 <= df3) & (df3 != df4)
result2 = pd.eval('df1 < df2 <= df3 != df4')
np.allclose(result1, result2)
Out[12]: True
Bitwise operators
pd.eval() supports the & and | bitwise operators:
Translator's note: the ~ (bitwise NOT) operator is also supported.
In [13]: result1 = (df1 < 0.5) & (df2 < 0.5) | (df3 < df4)
result2 = pd.eval('(df1 < 0.5) & (df2 < 0.5) | (df3 < df4)')
np.allclose(result1, result2)
Out[13]: True
In addition, it supports the use of the literal and and or in Boolean expressions:
Translator's note: the not logical operator is also supported.
In [14]: result3 = pd.eval('(df1 < 0.5) and (df2 < 0.5) or (df3 < df4)')
np.allclose(result1, result3)
Out[14]: True
Object attributes and indices
pd.eval() supports access to object attributes via the obj.attr syntax, and indexes via the obj[index]
syntax:
In [15]: result1 = df2.T[0] + df3.iloc[1]
result2 = pd.eval('df2.T[0] + df3.iloc[1]')
np.allclose(result1, result2)
Out[15]: True
Other operations
Other operations such as function calls, conditional statements, loops, and other more involved constructs are currently
not implemented in pd.eval() . If you'd like to execute these more complicated types of expressions, you can use the
Numexpr library itself.
DataFrame.eval() for Column-Wise Operations
Just as Pandas has a top-level pd.eval() function, DataFrames have an eval() method that works in similar
ways. The benefit of the eval() method is that columns can be referred to by name. We'll use this labeled array as an
example:
In [16]: df = pd.DataFrame(rng.rand(1000, 3), columns=['A', 'B', 'C'])
df.head()
Out[16]:
          A         B         C
0  0.375506  0.406939  0.069938
1  0.069087  0.235615  0.154374
2  0.677945  0.433839  0.652324
3  0.264038  0.808055  0.347197
4  0.589161  0.252418  0.557789
Using pd.eval() as above, we can compute expressions with the three columns like this:
In [17]: result1 = (df['A'] + df['B']) / (df['C'] - 1)
result2 = pd.eval("(df.A + df.B) / (df.C - 1)")
np.allclose(result1, result2)
Out[17]: True
The DataFrame.eval() method allows much more succinct evaluation of expressions with the columns:
In [18]: result3 = df.eval('(A + B) / (C - 1)')
np.allclose(result1, result3)
Out[18]: True
Notice here that we treat column names as variables within the evaluated expression, and the result is what we would
wish.
Assignment in DataFrame.eval()
In addition to the options just discussed, DataFrame.eval() also allows assignment to any column. Let's use the
DataFrame from before, which has columns 'A' , 'B' , and 'C' :
In [19]: df.head()
Out[19]:
          A         B         C
0  0.375506  0.406939  0.069938
1  0.069087  0.235615  0.154374
2  0.677945  0.433839  0.652324
3  0.264038  0.808055  0.347197
4  0.589161  0.252418  0.557789
We can use df.eval() to create a new column 'D' and assign to it a value computed from the other columns:
In [20]: df.eval('D = (A + B) / C', inplace=True)
df.head()
Out[20]:
          A         B         C          D
0  0.375506  0.406939  0.069938  11.187620
1  0.069087  0.235615  0.154374   1.973796
2  0.677945  0.433839  0.652324   1.704344
3  0.264038  0.808055  0.347197   3.087857
4  0.589161  0.252418  0.557789   1.508776
In the same way, any existing column can be modified:
In [21]: df.eval('D = (A - B) / C', inplace=True)
df.head()
Out[21]:
          A         B         C         D
0  0.375506  0.406939  0.069938 -0.449425
1  0.069087  0.235615  0.154374 -1.078728
2  0.677945  0.433839  0.652324  0.374209
3  0.264038  0.808055  0.347197 -1.566886
4  0.589161  0.252418  0.557789  0.603708
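Two related behaviors of DataFrame.eval() assignment, sketched on a small made-up frame. Without inplace=True, eval() returns a new DataFrame and leaves the original untouched. A multi-line expression string can also make several assignments at once, and later lines may refer to columns created on earlier lines.

```python
import pandas as pd

df = pd.DataFrame({'A': [1.0, 2.0], 'B': [3.0, 4.0], 'C': [2.0, 4.0]})

# Without inplace=True, eval() returns a modified copy.
df2 = df.eval('D = (A + B) / C')

# Several assignments in one call; F uses the freshly created E.
df3 = df.eval('E = A + B\nF = E * 2')
```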
Local variables in DataFrame.eval()
The DataFrame.eval() method supports an additional syntax that lets it work with local Python variables. Consider
the following:
In [22]: column_mean = df.mean(1)
result1 = df['A'] + column_mean
result2 = df.eval('A + @column_mean')
np.allclose(result1, result2)
Out[22]: True
The @ character here marks a variable name rather than a column name, and lets you efficiently evaluate expressions
involving the two "namespaces": the namespace of columns, and the namespace of Python objects. Notice that this @
character is only supported by the DataFrame.eval() method, not by the pandas.eval() function, because the
pandas.eval() function only has access to the one (Python) namespace.
DataFrame.query() Method
The DataFrame has another method based on evaluated strings, called the query() method. Consider the
following:
In [23]: result1 = df[(df.A < 0.5) & (df.B < 0.5)]
result2 = pd.eval('df[(df.A < 0.5) & (df.B < 0.5)]')
np.allclose(result1, result2)
Out[23]: True
As with the example used in our discussion of DataFrame.eval(), this is an expression involving columns of the
DataFrame. It cannot be expressed using the DataFrame.eval() syntax, however! Instead, for this type of filtering
operation, you can use the query() method:
In [24]: result2 = df.query('A < 0.5 and B < 0.5')
np.allclose(result1, result2)
Out[24]: True
Compared to the masking expression, this is not only a more efficient computation but also much easier to read and
understand. Note that the query() method also accepts the @ flag to mark local variables:
In [25]: Cmean = df['C'].mean()
result1 = df[(df.A < Cmean) & (df.B < Cmean)]
result2 = df.query('A < @Cmean and B < @Cmean')
np.allclose(result1, result2)
Out[25]: True
Performance: When to Use These Functions
When considering whether to use these functions, there are two considerations: computation time and memory use.
Memory use is the most predictable aspect. As already mentioned, every compound expression involving NumPy arrays
or Pandas DataFrames will result in implicit creation of temporary arrays. For example, this:
In [26]: x = df[(df.A < 0.5) & (df.B < 0.5)]
Is roughly equivalent to this:
In [27]: tmp1 = df.A < 0.5
tmp2 = df.B < 0.5
tmp3 = tmp1 & tmp2
x = df[tmp3]
If the size of the temporary DataFrames is significant compared to your available system memory (typically several
gigabytes), then it's a good idea to use an eval() or query() expression. You can check the approximate size of
your array in bytes using this:
In [28]: df.values.nbytes
Out[28]: 32000
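The 32000 above is easy to check by hand: at this point df holds 1,000 rows and 4 float64 columns (A, B, C and the added D), and each float64 takes 8 bytes. The same arithmetic estimates the temporaries of a compound expression. A boolean mask costs 1 byte per row, and each frame in the earlier df1 + df2 + df3 + df4 example (100,000 × 100 floats) is about 80 MB. A sketch on a stand-in frame of the same shape:

```python
import numpy as np
import pandas as pd

# A stand-in with the same shape and dtype as df at this point (1000 x 4 floats).
df = pd.DataFrame(np.zeros((1000, 4)), columns=list('ABCD'))
size = df.values.nbytes                # 1000 * 4 * 8 bytes

# A boolean temporary such as df.A < 0.5 costs 1 byte per row.
tmp_bytes = (df.A < 0.5).values.nbytes

# Size of one 100,000 x 100 float64 frame from the earlier timing example.
big = 100_000 * 100 * np.dtype('float64').itemsize
```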
On the performance side, eval() can be faster even when you are not maxing out your system memory. The issue is
how your temporary DataFrames compare to the size of the L1 or L2 CPU cache on your system (typically a few
megabytes in 2016); if they are much bigger, then eval() can avoid some potentially slow movement of values
between the different memory caches. In practice, I find that the difference in computation time between the traditional
methods and the eval/query method is usually not significant; if anything, the traditional method is faster for smaller
arrays! The benefit of eval/query is mainly in the saved memory, and the sometimes cleaner syntax they offer.
We've covered most of the details of eval() and query() here; for more information on these, you can refer to the
Pandas documentation. In particular, different parsers and engines can be specified for running these queries; for details
on this, see the discussion within the "Enhancing Performance" section.
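As a brief sketch of those options on a made-up frame: engine='python' evaluates the expression with ordinary Python/NumPy operations. It therefore works even when Numexpr is not installed, though it is usually no faster than writing the expression directly. parser='python' follows strict Python grammar, in which & binds more tightly than <, so each comparison needs its own parentheses. The default 'pandas' parser relaxes this, which is what lets the earlier examples write and/or without extra parentheses.

```python
import numpy as np
import pandas as pd

rng = np.random.RandomState(0)
df = pd.DataFrame(rng.rand(100, 3), columns=['A', 'B', 'C'])

# Plain Python/NumPy evaluation; no Numexpr required.
res_python = df.eval('(A + B) / (C - 1)', engine='python')

# Strict Python grammar: parenthesise each comparison before combining with &.
sub = df.query('(A < 0.5) & (B < 0.5)', parser='python', engine='python')
```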
Further Resources
In this chapter, we've covered many of the basics of using Pandas effectively for data analysis. Still, much has been
omitted from our discussion. To learn more about Pandas, I recommend the following resources:
Pandas online documentation: This is the go-to source for complete documentation of the package. While the
examples in the documentation tend to be small generated datasets, the description of the options is complete and
generally very useful for understanding the use of various functions.
Python for Data Analysis: Written by Wes McKinney (the original creator of Pandas), this book contains much more
detail on the Pandas package than we had room for in this chapter. In particular, he takes a deep dive into tools for
time series, which were his bread and butter as a financial consultant. The book also has many entertaining
examples of applying Pandas to gain insight from real-world datasets. Keep in mind, though, that the book is now
several years old, and the Pandas package has quite a few new features that this book does not cover (but be on the
lookout for a new edition in 2017).
Stack Overflow: Pandas has so many users that any question you have has likely been asked and answered on
Stack Overflow. Using Pandas is a case where some Google-Fu is your best friend. Simply go to your favorite search
engine and type in the question, problem, or error you're coming across; more than likely you'll find your answer on a
Stack Overflow page.
Pandas on PyVideo: From PyCon to SciPy to PyData, many conferences have featured tutorials from Pandas
developers and power users. The PyCon tutorials in particular tend to be given by very well-vetted presenters.
Using these resources, combined with the walk-through given in this chapter, my hope is that you'll be poised to use
Pandas to tackle any data analysis problem you come across!
Visualization with Matplotlib
We'll now take an in-depth look at the Matplotlib package for visualization in Python. Matplotlib is a multi-platform data
visualization library built on NumPy arrays, and designed to work with the broader SciPy stack. It was conceived by John
Hunter in 2002, originally as a patch to IPython for enabling interactive MATLAB-style plotting via gnuplot from the
IPython command line. IPython's creator, Fernando Perez, was at the time scrambling to finish his PhD, and let John
know he wouldn’t have time to review the patch for several months. John took this as a cue to set out on his own, and the
Matplotlib package was born, with version 0.1 released in 2003. It received an early boost when it was adopted as the
plotting package of choice of the Space Telescope Science Institute (the folks behind the Hubble Telescope), which
financially supported Matplotlib’s development and greatly expanded its capabilities.
One of Matplotlib’s most important features is its ability to play well with many operating systems and graphics backends.
Matplotlib supports dozens of backends and output types, which means you can count on it to work regardless of which
operating system you are using or which output format you wish. This cross-platform, everything-to-everyone approach
has been one of the great strengths of Matplotlib. It has led to a large user base, which in turn has led to an active
developer base and Matplotlib’s powerful tools and ubiquity within the scientific Python world.
In recent years, however, the interface and style of Matplotlib have begun to show their age. Newer tools like ggplot and
ggvis in the R language, along with web visualization toolkits based on D3js and HTML5 canvas, often make Matplotlib
feel clunky and old-fashioned. Still, I'm of the opinion that we cannot ignore Matplotlib's strength as a well-tested, cross-platform graphics engine. Recent Matplotlib versions make it relatively easy to set new global plotting styles (see
Customizing Matplotlib: Configurations and Style Sheets), and people have been developing new packages that build on
its powerful internals to drive Matplotlib via cleaner, more modern APIs—for example, Seaborn (discussed in
Visualization With Seaborn), ggpy, HoloViews, Altair, and even Pandas itself can be used as wrappers around
Matplotlib's API. Even with wrappers like these, it is still often useful to dive into Matplotlib's syntax to adjust the final plot
output. For this reason, I believe that Matplotlib itself will remain a vital piece of the data visualization stack, even if new
tools mean the community gradually moves away from using the Matplotlib API directly.
General Matplotlib Tips
Before we dive into the details of creating visualizations with Matplotlib, there are a few useful things you should know
about using the package.
Importing Matplotlib
Just as we use the np shorthand for NumPy and the pd shorthand for Pandas, we will use some standard shorthands
for Matplotlib imports:
In [1]: import matplotlib as mpl
import matplotlib.pyplot as plt
The plt interface is what we will use most often, as we shall see throughout this chapter.
Setting Styles
We will use the plt.style directive to choose appropriate aesthetic styles for our figures. Here we will set the
classic style, which ensures that the plots we create use the classic Matplotlib style:
In [2]: plt.style.use('classic')
Throughout this section, we will adjust this style as needed. Note that the stylesheets used here are supported as of
Matplotlib version 1.5; if you are using an earlier version of Matplotlib, only the default style is available. For more
information on stylesheets, see Customizing Matplotlib: Configurations and Style Sheets.
在本章中,我们会根据需要调整风格设置。这里要说明的是,这些样式单从Matplotlib 1.5版本开始才被支持;如果你在使用更早期的版
本,那么Matplotlib只能使用默认风格。要获取关于样式单的更多内容,参见自定义matplotlib:配置和样式单。
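The style sheets that ship with your installation can be listed, and a style can also be applied only temporarily. The sketch below uses plt.style.available and the plt.style.context() context manager; forcing the non-interactive Agg backend is an assumption made only so the snippet runs without a display, and the plotted data are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; not needed in a notebook
import matplotlib.pyplot as plt

# Style sheets bundled with this Matplotlib installation
available = sorted(plt.style.available)
print(available[:5])

# Apply the 'classic' style only inside the with-block
facecolor_before = plt.rcParams["axes.facecolor"]
with plt.style.context("classic"):
    fig, ax = plt.subplots()
    ax.plot([0, 1, 2], [0, 1, 4])
facecolor_after = plt.rcParams["axes.facecolor"]  # global settings restored
print(facecolor_after == facecolor_before)
```

Unlike plt.style.use(), which changes the settings for the rest of the session, the context manager restores the previous settings when the block exits.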
show() or No show()? How to Display Your Plots
show() 还是不用 show() ?如何显示你的图表
A visualization you can't see won't be of much use, but just how you view your Matplotlib plots depends on the context.
The best use of Matplotlib differs depending on how you are using it; roughly, the three applicable contexts are using
Matplotlib in a script, in an IPython terminal, or in an IPython notebook.
一张你看不到的图表不会有什么用处,但是显示图表的方法取决于使用环境。Matplotlib的最佳用法因你的使用方式而异;通常有三种应用
场景:在脚本文件中使用、在IPython终端中使用以及在IPython notebook中使用。
Plotting from a script
在脚本文件中作图
If you are using Matplotlib from within a script, the function plt.show() is your friend. plt.show() starts an event
loop, looks for all currently active figure objects, and opens one or more interactive windows that display your figure or
figures.
如果你是在脚本文件中使用Matplotlib, plt.show() 就是你的好帮手。 plt.show() 会启动一个事件循环,找到所有当前激活的图表
对象,然后打开一个或多个交互式窗口来显示你的图表。
So, for example, you may have a file called myplot.py containing the following:
例如,假设你有一个名为myplot.py的文件,包含以下代码:
# ------- file: myplot.py ------
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
plt.plot(x, np.sin(x))
plt.plot(x, np.cos(x))
plt.show()
You can then run this script from the command-line prompt, which will result in a window opening with your figure
displayed:
你可以在命令行中运行这个脚本文件,运行结果会打开一个窗口,显示你的图表:
$ python myplot.py
The plt.show() command does a lot under the hood, as it must interact with your system's interactive graphical
backend. The details of this operation can vary greatly from system to system and even installation to installation, but
Matplotlib does its best to hide all these details from you.
plt.show() 函数在底层做了许多工作,因为它需要和你系统的交互式图形后端通信。这个操作的细节根据系统不同甚至安装方式不同
会有很大区别,但Matplotlib尽最大可能为用户屏蔽了这些底层实现细节。
One thing to be aware of: the plt.show() command should be used only once per Python session, and is most often
seen at the very end of the script. Multiple show() commands can lead to unpredictable backend-dependent behavior,
and should mostly be avoided.
还要提醒一下: plt.show() 函数在每个Python会话中应该只使用一次,最常见的情况就是在脚本的末尾使用它。多次调用 show() 函数
可能导致不可预料的、依赖于后端的行为,应该尽量避免。
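To make "all currently active figure objects" concrete, the sketch below creates two figures and lists them with plt.get_fignums(); in a real script, a single plt.show() at the very end would open both windows at once. Forcing the Agg backend is an assumption so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

plt.close("all")  # start with no open figures
x = np.linspace(0, 10, 100)

fig1 = plt.figure()
plt.plot(x, np.sin(x))

fig2 = plt.figure()
plt.plot(x, np.cos(x))

# These are the "currently active figure objects" that one plt.show()
# call at the end of a script would display together.
print(plt.get_fignums())  # [1, 2]
```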
Plotting from an IPython shell
在IPython终端中作图
It can be very convenient to use Matplotlib interactively within an IPython shell (see IPython: Beyond Normal Python).
IPython is built to work well with Matplotlib if you specify Matplotlib mode. To enable this mode, you can use the
%matplotlib magic command after starting ipython :
在IPython终端(参见IPython:超越Python解释器)中交互式使用Matplotlib是非常方便的。如果指定了Matplotlib模式,IPython就能与
Matplotlib很好地协同工作。要激活这个模式,你可以在启动 ipython 后使用 %matplotlib 魔术指令:
In [1]: %matplotlib
Using matplotlib backend: TkAgg
In [2]: import matplotlib.pyplot as plt
At this point, any plt plot command will cause a figure window to open, and further commands can be run to update
the plot. Some changes (such as modifying properties of lines that are already drawn) will not draw automatically: to force
an update, use plt.draw() . Using plt.show() in Matplotlib mode is not required.
这之后任何 plt 的作图命令都会打开一个包含图表的窗口,后续运行的命令还能更新图表。某些改变(例如修改已经画好的线条的属
性)不会自动重绘,这时可以使用 plt.draw() 来强制更新。在Matplotlib模式下不需要使用 plt.show() 。
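The kind of change that needs an explicit redraw can be sketched as follows: Line2D setter methods modify a line that is already on the axes, and plt.draw() then asks the current figure to re-render. The Agg backend is forced here only so the sketch runs headless; in an IPython shell with %matplotlib active you would see the window update.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
(line,) = ax.plot(x, np.sin(x))  # plot() returns a list of Line2D objects

# Change properties of a line that has already been drawn...
line.set_color("red")
line.set_linewidth(3)

# ...then ask for a redraw; in an interactive IPython session this
# refreshes the open figure window.
plt.draw()
```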
Plotting from an IPython notebook
在IPython notebook中作图
The IPython notebook is a browser-based interactive data analysis tool that can combine narrative, code, graphics,
HTML elements, and much more into a single executable document (see IPython: Beyond Normal Python).
IPython notebook是一个基于浏览器的交互式数据分析工具,能将说明、代码、图像、HTML元素和其他内容都合成在一个可执行文档中(参见
IPython:超越Python解释器)。
Plotting interactively within an IPython notebook can be done with the %matplotlib command, and works in a similar
way to the IPython shell. In the IPython notebook, you also have the option of embedding graphics directly in the
notebook, with two possible options:
%matplotlib notebook will lead to interactive plots embedded within the notebook
%matplotlib inline will lead to static images of your plot embedded in the notebook
在IPython notebook中也可以使用 %matplotlib 魔术指令交互式地作图,其工作方式与在IPython终端中类似。而且在IPython
notebook中,你还可以选择将图表直接嵌入notebook中显示,有两个可选参数:
%matplotlib notebook :在notebook中嵌入具有交互功能的图表
%matplotlib inline :在notebook中嵌入图表的静态图像
For this book, we will generally opt for %matplotlib inline :
本书中,我们通常使用 %matplotlib inline :
In [3]: %matplotlib inline
After running this command (it needs to be done only once per kernel/session), any cell within the notebook that creates
a plot will embed a PNG image of the resulting graphic:
运行了这条魔术指令后(每个内核/会话只需要运行一次),notebook中后续任何创建图表的单元格都会嵌入一张结果图形的PNG图
像:
In [4]: import numpy as np
x = np.linspace(0, 10, 100)
fig = plt.figure()
plt.plot(x, np.sin(x), '-')
plt.plot(x, np.cos(x), '--');
Saving Figures to File
将图表保存到文件
One nice feature of Matplotlib is the ability to save figures in a wide variety of formats. Saving a figure can be done using
the savefig() command. For example, to save the previous figure as a PNG file, you can run this:
Matplotlib还有一个非常棒的功能,那就是能将图表保存成多种不同的文件格式。保存图表可以通过 savefig() 函数实现。例如,如果我
们需要将上面的图表保存成一个PNG文件,只需要执行下面的代码:
In [5]: fig.savefig('my_figure.png')
We now have a file called my_figure.png in the current working directory:
现在当前工作目录下就有了一个名为 my_figure.png 的文件:
In [6]: !ls -lh my_figure.png
-rw-rw-r-- 1 wangy wangy 26K 12月 6 15:47 my_figure.png
To confirm that it contains what we think it contains, let's use the IPython Image object to display the contents of this
file:
为了确认文件内容与我们的预期一致,我们可以使用IPython的 Image 对象将这个图像文件显示出来:
In [7]: from IPython.display import Image
Image('my_figure.png')
In savefig() , the file format is inferred from the extension of the given filename. Depending on what backends you
have installed, many different file formats are available. The list of supported file types can be found for your system by
using the following method of the figure canvas object:
在 savefig() 函数中,保存文件的格式是根据给定文件名的扩展名推断的。根据你安装的后端不同,可用的文件格式也有所不同。你可以
通过图表画布对象的下面这个方法列出你系统支持的所有文件类型:
In [8]: fig.canvas.get_supported_filetypes()
Out[8]: {'ps': 'Postscript',
'eps': 'Encapsulated Postscript',
'pdf': 'Portable Document Format',
'pgf': 'PGF code for LaTeX',
'png': 'Portable Network Graphics',
'raw': 'Raw RGBA bitmap',
'rgba': 'Raw RGBA bitmap',
'svg': 'Scalable Vector Graphics',
'svgz': 'Scalable Vector Graphics',
'jpg': 'Joint Photographic Experts Group',
'jpeg': 'Joint Photographic Experts Group',
'tif': 'Tagged Image File Format',
'tiff': 'Tagged Image File Format'}
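The extension-based inference can be checked by saving the same figure under several names and reading back the first bytes of each file. This is a minimal sketch that writes into a temporary directory (an assumption so it leaves nothing behind) and forces the Agg backend; the last call shows that when format= is given explicitly, the filename is used as-is.

```python
import os
import tempfile
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x), "-")
ax.plot(x, np.cos(x), "--")

outdir = tempfile.mkdtemp()
first_bytes = {}
for name in ["my_figure.png", "my_figure.pdf", "my_figure.svg"]:
    path = os.path.join(outdir, name)
    fig.savefig(path)  # format inferred from the file extension
    with open(path, "rb") as f:
        first_bytes[name] = f.read(8)

# With format= given explicitly, the filename is used verbatim
raw_path = os.path.join(outdir, "figure_without_extension")
fig.savefig(raw_path, format="png")
with open(raw_path, "rb") as f:
    first_bytes["figure_without_extension"] = f.read(8)

print(sorted(fig.canvas.get_supported_filetypes())[:5])
```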
Note that when saving your figure, it's not necessary to use plt.show() or related commands discussed earlier.
注意当你保存图表的时候,不需要使用 plt.show() 或前面讨论过的相关命令。
Two Interfaces for the Price of One
两套不同接口
A potentially confusing feature of Matplotlib is its dual interfaces: a convenient MATLAB-style state-based interface, and a
more powerful object-oriented interface. We'll quickly highlight the differences between the two here.
Matplotlib一个可能令人迷惑的地方是它具有两套接口:一套是方便的、基于状态的MATLAB风格接口,另一套是更强大的面向对象接口。
我们在这里简单介绍一下它们的区别。
MATLAB-style Interface
MATLAB风格接口
Matplotlib was originally written as a Python alternative for MATLAB users, and much of its syntax reflects that fact. The
MATLAB-style tools are contained in the pyplot ( plt ) interface. For example, the following code will probably look quite
familiar to MATLAB users:
Matplotlib最早是作为MATLAB用户在Python环境下的替代品而编写的,因此很多语法都反映了这一点。MATLAB风格的工具都封装在
pyplot ( plt )接口中。例如,下面的代码对于MATLAB用户来说不会陌生:
In [9]: plt.figure()  # create a plot figure
# create the first of two panels and set current axis
plt.subplot(2, 1, 1) # (rows, columns, panel number)
plt.plot(x, np.sin(x))
# create the second panel and set current axis
plt.subplot(2, 1, 2)
plt.plot(x, np.cos(x));
It is important to note that this interface is stateful: it keeps track of the "current" figure and axes, which are where all
plt commands are applied. You can get a reference to these using the plt.gcf() (get current figure) and
plt.gca() (get current axes) routines.
需要提醒的是这套接口是有状态的:它会跟踪“当前”的图表和维度,所有的 plt 命令都作用在它们上面。你可以通过 plt.gcf() (获取
当前图表)和 plt.gca() (获取当前维度)函数获得它们的引用。
While this stateful interface is fast and convenient for simple plots, it is easy to run into problems. For example, once the
second panel is created, how can we go back and add something to the first? This is possible within the MATLAB-style
interface, but a bit clunky. Fortunately, there is a better way.
虽然这种有状态的接口在创建简单图表时非常快速和方便,但是也很容易遇到问题。例如,当第二个子图表创建之后,我们如何回到第一个
子图表中增加内容呢?虽然在MATLAB风格接口中也可以实现,但是有些别扭。幸运的是,我们有更好的办法。
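One way to go back to the first panel within the MATLAB-style interface, assuming you kept a reference to each panel when creating it, is plt.sca() ("set current axes"). The sketch below also confirms that plt.gcf() and plt.gca() point at the most recently created figure and panel; the Agg backend is forced only so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig = plt.figure()

ax1 = plt.subplot(2, 1, 1)  # keep a reference to each panel
plt.plot(x, np.sin(x))
ax2 = plt.subplot(2, 1, 2)
plt.plot(x, np.cos(x))

print(plt.gcf() is fig, plt.gca() is ax2)  # True True

# Make the first panel "current" again; subsequent plt commands apply to it.
plt.sca(ax1)
plt.title("sin(x)")
```

Keeping those references is exactly the bookkeeping that the object-oriented interface makes explicit.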
Object-oriented interface
面向对象接口
The object-oriented interface is available for these more complicated situations, and for when you want more control over
your figure. Rather than depending on some notion of an "active" figure or axes, in the object-oriented interface the
plotting functions are methods of explicit Figure and Axes objects. To re-create the previous plot using this style of
plotting, you might do the following:
面向对象接口适用于这种更复杂的场景,以及你需要对图表进行更多控制的场景。面向对象接口不依赖于“当前活跃的”图表或维度的概念,
其作图函数都是具体的 Figure 和 Axes 对象的方法。采用这种接口重新创建上面的图表的代码如下:
In [10]: # First create a grid of plots
# fig is the Figure object; ax is an array of two Axes objects
fig, ax = plt.subplots(2)
# Call the plot() method on each of the two Axes objects
ax[0].plot(x, np.sin(x))
ax[1].plot(x, np.cos(x));
For more simple plots, the choice of which style to use is largely a matter of preference, but the object-oriented approach
can become a necessity as plots become more complicated. Throughout this chapter, we will switch between the
MATLAB-style and object-oriented interfaces, depending on what is most convenient. In most cases, the difference is as
small as switching plt.plot() to ax.plot() , but there are a few gotchas that we will highlight as they come up in
the following sections.
对于简单的图表来说,采用哪种风格的接口作图很大程度上取决于个人喜好,但是随着图表变得复杂,面向对象接口可能会变得必不可少。在
本章中,我们会根据哪种方式更方便,在两者之间进行切换。在大多数情况下,两者的区别只是将 plt.plot() 换成 ax.plot() ,但是其中
也有一些坑,我们会在后续小节中遇到时着重指出。
Simple Line Plots
简单的折线图
Perhaps the simplest of all plots is the visualization of a single function $y = f(x)$. Here we will take a first look at creating
a simple plot of this type. As with all the following sections, we'll start by setting up the notebook for plotting and importing
the packages we will use:
对于图表来说,最简单的莫过于单一函数$y=f(x)$的可视化。本节中我们首先来初步介绍如何创建这种类型的图表。和后续所有小节一样,
我们都会先使用下面的代码设置好notebook并载入需要用到的包:
In [1]: %matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('seaborn-whitegrid')  # in Matplotlib >= 3.6 this style is named 'seaborn-v0_8-whitegrid'
import numpy as np
For all Matplotlib plots, we start by creating a figure and an axes. In their simplest form, a figure and axes can be created
as follows:
对于所有的Matplotlib图表,我们都从创建图形和维度开始。最简单的形式下,图形和维度可以通过下面的代码创建:
In [2]: fig = plt.figure()
ax = plt.axes()
In Matplotlib, the figure (an instance of the class plt.Figure ) can be thought of as a single container that contains all
the objects representing axes, graphics, text, and labels. The axes (an instance of the class plt.Axes ) is what we see
above: a bounding box with ticks and labels, which will eventually contain the plot elements that make up our
visualization. Throughout this book, we'll commonly use the variable name fig to refer to a figure instance, and ax to
refer to an axes instance or group of axes instances.
在Matplotlib中,图形(类 plt.Figure 的一个实例)可以被认为是一个包含所有维度、图像、文本和标签对象的容器。维度(类
plt.Axes 的一个实例)就是上面看到的那个带有刻度和标签的方框,它最终会包含组成可视化图表的各种元素。在本书中,我们通常使用
变量名 fig 来指代图形实例,使用变量名 ax 来指代一个或一组维度实例。
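The container relationship described above can be checked directly: fig.axes lists the axes the figure holds, each axes points back to its figure, and ax.lines holds the Line2D artists drawn on it. A minimal sketch, with the Agg backend forced only so it runs headless:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
import numpy as np

fig = plt.figure()
ax = plt.axes()
x = np.linspace(0, 10, 1000)
(line,) = ax.plot(x, np.sin(x))

# The figure holds the axes; the axes holds the plotted Line2D artists.
print(isinstance(fig, Figure), isinstance(ax, Axes))
print(ax in fig.axes, ax.figure is fig, line in ax.lines)
```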
Once we have created an axes, we can use the ax.plot function to plot some data. Let's start with a simple sinusoid:
一旦我们创建了维度,就可以使用 ax.plot 方法将数据绘制在图表上。下面从一个简单的正弦曲线开始:
In [3]: fig = plt.figure()
ax = plt.axes()
x = np.linspace(0, 10, 1000)
ax.plot(x, np.sin(x));
Alternatively, we can use the pylab interface and let the figure and axes be created for us in the background (see Two
Interfaces for the Price of One for a discussion of these two interfaces):
或者,我们也可以使用pylab接口(MATLAB风格的接口),让图形和维度在后台自动创建(关于这两种接口的讨论参见两套不同接口):
In [4]: plt.plot(x, np.sin(x));
If we want to create a single figure with multiple lines, we can simply call the plot function multiple times:
如果我们需要在同一幅图形中绘制多根线条,只需要多次调用 plot 函数即可:
In [5]: plt.plot(x, np.sin(x))
plt.plot(x, np.cos(x));
That's all there is to plotting simple functions in Matplotlib! We'll now dive into some more details about how to control the
appearance of the axes and lines.
在Matplotlib中绘制简单函数图像就是这么简单!下面我们深入了解一下控制坐标轴和线条外观的细节。
Adjusting the Plot: Line Colors and Styles
调整折线图:线条颜色和风格
The first adjustment you might wish to make to a plot is to control the line colors and styles. The plt.plot() function
takes additional arguments that can be used to specify these. To adjust the color, you can use the color keyword,
which accepts a string argument representing virtually any imaginable color. The color can be specified in a variety of
ways:
你可能第一个想对图表进行的调整就是控制线条的颜色和风格。 plt.plot() 函数接受额外的参数来指定它们。通过 color 关键字参数
可以调整颜色,这个字符串类型参数基本上能用来代表任何你能想到的颜色。可以通过多种方式指定颜色:
Translator's note: any HTML color name can be used; see a list of HTML color names for the full set.
In [6]: plt.plot(x, np.sin(x - 0), color='blue')        # specify color by name
plt.plot(x, np.sin(x - 1), color='g')           # short color code (rgbcmyk)
plt.plot(x, np.sin(x - 2), color='0.75')        # grayscale between 0 and 1
plt.plot(x, np.sin(x - 3), color='#FFDD44')     # hex code (RRGGBB from 00 to FF)
plt.plot(x, np.sin(x - 4), color=(1.0,0.2,0.3)) # RGB tuple, values 0 to 1
plt.plot(x, np.sin(x - 5), color='chartreuse'); # all HTML color names supported
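All of these color specifications end up as the same kind of RGB value. The sketch below runs each of the six specs from the example through matplotlib.colors.to_rgb and to_hex, which is a convenient way to check what a given string or tuple will actually produce before plotting.

```python
from matplotlib.colors import to_hex, to_rgb

# The six ways of specifying a color used in the example above
specs = ["blue", "g", "0.75", "#FFDD44", (1.0, 0.2, 0.3), "chartreuse"]
for spec in specs:
    # to_rgb normalizes every form to an (r, g, b) tuple of floats in [0, 1]
    print(repr(spec), "->", to_rgb(spec), to_hex(spec))
```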
If no color is specified, Matplotlib will automatically cycle through a set of default colors for multiple lines.
如果没有指定颜色,Matplotlib会在一组默认颜色中循环,为每一条线条分配颜色。
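The default set of colors comes from the axes.prop_cycle setting in rcParams, so it depends on the active style. This sketch, which forces the Agg backend so it runs headless, draws three lines without a color and compares the colors they received with the first three entries of that cycle.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

# The default color cycle of the current style
cycle_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
for shift in range(3):
    ax.plot(x, np.sin(x - shift))  # no color given

used = [line.get_color() for line in ax.lines]
print(used)
print(cycle_colors[:3])
```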
Similarly, the line style can be adjusted using the linestyle keyword:
类似地,通过 linestyle 关键字参数可以调整线条的风格:
In [7]: plt.plot(x, x + 0, linestyle='solid')
plt.plot(x, x + 1, linestyle='dashed')
plt.plot(x, x + 2, linestyle='dashdot')
plt.plot(x, x + 3, linestyle='dotted');
# For short, you can use the following codes:
plt.plot(x, x + 4, linestyle='-')  # solid
plt.plot(x, x + 5, linestyle='--') # dashed
plt.plot(x, x + 6, linestyle='-.') # dashdot
plt.plot(x, x + 7, linestyle=':');  # dotted
If you would like to be extremely terse, these linestyle and color codes can be combined into a single non-keyword argument to the plt.plot() function:
如果你喜欢更简洁的代码,这些 linestyle 和 color 代码能够合并成一个非关键字参数,传递给 plt.plot() 函数:
In [8]: plt.plot(x, x + 0, '-g')  # solid green
plt.plot(x, x + 1, '--c') # dashed cyan
plt.plot(x, x + 2, '-.k') # dashdot black
plt.plot(x, x + 3, ':r');  # dotted red
These single-character color codes reflect the standard abbreviations in the RGB (Red/Green/Blue) and CMYK
(Cyan/Magenta/Yellow/blacK) color systems, commonly used for digital color graphics.
上面的单字母颜色码是RGB(红/绿/蓝)和CMYK(青/品红/黄/黑)颜色系统中的标准缩写,这两种颜色系统被广泛应用于数字彩色图像。
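The shorthand and the keyword forms describe the same thing. In the sketch below (Agg backend forced so it runs headless), a line drawn with the format string '--c' and one drawn with linestyle='dashed', color='c' end up with identical line style and color; the long name 'dashed' is normalized to '--'.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_rgba

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
(short,) = ax.plot(x, x + 1, "--c")                            # format string
(verbose,) = ax.plot(x, x + 2, linestyle="dashed", color="c")  # keywords

# Both spellings end up as the same line style and color
print(short.get_linestyle(), verbose.get_linestyle())  # -- --
print(to_rgba(short.get_color()) == to_rgba(verbose.get_color()))  # True
```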
There are many other keyword arguments that can be used to fine-tune the appearance of the plot; for more details, I'd
suggest viewing the docstring of the plt.plot() function using IPython's help tools (See Help and Documentation in
IPython).
还有很多其他的关键字参数可以对图表的外观进行精细调整;作者建议在IPython中使用帮助工具(参见IPython的帮助和文档)查看
plt.plot() 函数的文档字符串来获得更多细节内容。
Adjusting the Plot: Axes Limits
调整折线图:坐标轴范围
Matplotlib does a decent job of choosing default axes limits for your plot, but sometimes it's nice to have finer control. The
most basic way to adjust axis limits is to use the plt.xlim() and plt.ylim() methods:
Matplotlib会为你的图表选择相当合适的默认坐标轴范围,但是有些情况下你也需要更精细的控制。调整坐标轴范围最基本的方法是使用
plt.xlim() 和 plt.ylim() 函数:
In [9]: plt.plot(x, np.sin(x))
plt.xlim(-1, 11)
plt.ylim(-1.5, 1.5);
If for some reason you'd like either axis to be displayed in reverse, you can simply reverse the order of the arguments:
如果某些情况下你希望将坐标轴反向显示,只需将参数的顺序颠倒即可:
In [10]: plt.plot(x, np.sin(x))
plt.xlim(10, 0)
plt.ylim(1.2, -1.2);
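Reversed limits are stored as given, so the axis is reported as inverted afterwards. A small sketch (Agg backend forced so it runs headless) that reads the limits back with the object-oriented getters after using the plt functions from the example:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

plt.xlim(10, 0)     # acts on the current axes, i.e. ax
plt.ylim(1.2, -1.2)

print(ax.get_xlim(), ax.get_ylim())
print(ax.xaxis_inverted(), ax.yaxis_inverted())
```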
A useful related method is plt.axis() (note here the potential confusion between axes with an e, and axis with an i).
The plt.axis() method allows you to set the x and y limits with a single call, by passing a list which specifies
[xmin, xmax, ymin, ymax] :
相关的函数还有 plt.axis() (注意不要混淆带e的 axes 和带i的 axis)。这个函数可以在一次调用中完成x轴和y轴范围的设置,传递一个
[xmin, xmax, ymin, ymax] 的列表参数即可:
In [11]: plt.plot(x, np.sin(x))
plt.axis([-1, 11, -1.5, 1.5]);
The plt.axis() method goes even beyond this, allowing you to do things like automatically tighten the bounds
around the current plot:
plt.axis() 函数的功能不仅如此,它还能像下面代码一样自动将坐标轴范围收紧到刚好容纳当前图像的大小:
In [12]: plt.plot(x, np.sin(x))
plt.axis('tight');
It allows even higher-level specifications, such as ensuring an equal aspect ratio so that on your screen, one unit in x is
equal to one unit in y :
它甚至支持更高层次的设置,例如通过 'equal' 参数确保等长宽比,使得在屏幕上 x 轴的一个单位长度与 y 轴的一个单位长度相等:
In [13]: plt.plot(x, np.sin(x))
plt.axis('equal');
For more information on axis limits and the other capabilities of the plt.axis method, refer to the plt.axis
docstring.
更多关于坐标轴范围以及 plt.axis 函数其他功能的内容,请查阅 plt.axis 的文档字符串。
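The list form of plt.axis() is shorthand for setting both ranges, and it returns the resulting (xmin, xmax, ymin, ymax) tuple. The sketch below (Agg backend forced so it runs headless) compares it with separate plt.xlim()/plt.ylim() calls on a second figure.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)

# One call with [xmin, xmax, ymin, ymax] ...
fig1, ax1 = plt.subplots()
ax1.plot(x, np.sin(x))
limits = plt.axis([-1, 11, -1.5, 1.5])
print(limits)

# ... does the same as setting the x and y ranges separately
fig2, ax2 = plt.subplots()
ax2.plot(x, np.sin(x))
plt.xlim(-1, 11)
plt.ylim(-1.5, 1.5)
```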
Labeling Plots
折线图标签
As the last piece of this section, we'll briefly look at the labeling of plots: titles, axis labels, and simple legends.
本节最后简单介绍一下图表的标签:标题、坐标轴标签和简单的图例。
Titles and axis labels are the simplest such labels—there are methods that can be used to quickly set them:
标题和坐标轴标签是最简单的标签,Matplotlib提供了可以快速设置它们的函数:
In [14]: plt.plot(x, np.sin(x))
plt.title("A Sine Curve")
plt.xlabel("x")
plt.ylabel("sin(x)");
The position, size, and style of these labels can be adjusted using optional arguments to the function. For more
information, see the Matplotlib documentation and the docstrings of each of these functions.
这些标签的位置、大小和风格可以通过上面函数的可选参数进行设置。参阅Matplotlib在线文档和这些函数的文档字符串可以获得更多的信
息。
When multiple lines are being shown within a single axes, it can be useful to create a plot legend that labels each line
type. Again, Matplotlib has a built-in way of quickly creating such a legend. It is done via the (you guessed it)
plt.legend() method. Though there are several valid ways of using this, I find it easiest to specify the label of each
line using the label keyword of the plot function:
当一幅图中绘制了多条折线时,创建一个标注每条线条的图例会很有用。Matplotlib也内建了快速创建图例的方法。估计你也猜到了,通过
plt.legend() 函数可以实现这个需求。虽然有多种正确的方法来使用它,作者认为最简单的方法是在绘制每条线条时使用 plot 函数的
label 关键字参数来指定它的标签:
In [15]: plt.plot(x, np.sin(x), '-g', label='sin(x)')
plt.plot(x, np.cos(x), ':b', label='cos(x)')
plt.axis('equal')
plt.legend();
As you can see, the plt.legend() function keeps track of the line style and color, and matches these with the correct
label. More information on specifying and formatting plot legends can be found in the plt.legend docstring;
additionally, we will cover some more advanced legend options in Customizing Plot Legends.
如上图可见, plt.legend() 函数会跟踪线条的风格和颜色,并将它们与正确的标签对应起来。查阅 plt.legend 的文档字符串可以获得
更多关于指定和格式化图例的信息;我们在自定义图表图例一节中也会讨论更高级的图例选项。
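The pairing of lines and labels can be inspected with ax.get_legend_handles_labels(), which returns exactly the artists and labels the legend will use. In this sketch (Agg backend forced so it runs headless), a third line is drawn without a label; it receives an automatic label starting with an underscore and is therefore left out of the legend.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 1000)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x), "-g", label="sin(x)")
ax.plot(x, np.cos(x), ":b", label="cos(x)")
ax.plot(x, np.zeros_like(x), color="gray")  # no label: left out of the legend
leg = ax.legend()

handles, labels = ax.get_legend_handles_labels()
print(labels)  # ['sin(x)', 'cos(x)']
legend_texts = [t.get_text() for t in leg.get_texts()]
```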
Aside: Matplotlib Gotchas
额外内容:Matplotlib的坑
While most plt functions translate directly to ax methods (such as plt.plot() → ax.plot() ,
plt.legend() → ax.legend() , etc.), this is not the case for all commands. In particular, functions to set limits,
labels, and titles are slightly modified. For transitioning between MATLAB-style functions and object-oriented methods,
make the following changes:
虽然大多数的 plt 函数都可以直接转换为 ax 的方法进行调用(例如 plt.plot() → ax.plot() 、 plt.legend() →
ax.legend() 等),但是并不是所有的命令都是这样。特别是用于设置坐标轴范围、标签和标题的函数都有一些变化。在MATLAB风格的
函数与面向对象的方法之间转换时,需要做以下改变:
plt.xlabel() → ax.set_xlabel()
plt.ylabel() → ax.set_ylabel()
plt.xlim() → ax.set_xlim()
plt.ylim() → ax.set_ylim()
plt.title() → ax.set_title()
In the object-oriented interface to plotting, rather than calling these functions individually, it is often more convenient to
use the ax.set() method to set all these properties at once:
在面向对象接口中,与其逐个调用上面的方法来设置这些属性,通常更方便的做法是使用 ax.set() 方法一次性设置所有的属性:
In [16]: ax = plt.axes()
ax.plot(x, np.sin(x))
ax.set(xlim=(0, 10), ylim=(-2, 2),
xlabel='x', ylabel='sin(x)',
title='A Simple Plot');
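To check that the translation table above really gives the same result, the following sketch builds one plot with the MATLAB-style functions and one with the corresponding set_* methods, then compares their settings through the getters. The Agg backend is forced only so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

# MATLAB-style: functions act on the current axes
plt.figure()
plt.plot(x, np.sin(x))
plt.xlim(0, 10)
plt.ylim(-2, 2)
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.title("A Simple Plot")
ax_state = plt.gca()

# Object-oriented: the same settings through set_* methods
fig, ax_oo = plt.subplots()
ax_oo.plot(x, np.sin(x))
ax_oo.set_xlim(0, 10)
ax_oo.set_ylim(-2, 2)
ax_oo.set_xlabel("x")
ax_oo.set_ylabel("sin(x)")
ax_oo.set_title("A Simple Plot")
```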
Simple Scatter Plots
简单散点图
Another commonly used plot type is the simple scatter plot, a close cousin of the line plot. Instead of points being joined
by line segments, here the points are represented individually with a dot, circle, or other shape. We’ll start by setting up
the notebook for plotting and importing the functions we will use:
另一种常用的图表类型是简单散点图,它是折线图的近亲。与折线图中由线段连接的点不同,散点图中的每个点都单独用圆点、圆圈或其他
形状表示。本节开始我们也是首先设置好notebook,并导入需要用到的函数:
In [2]: %matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('seaborn-whitegrid')  # in Matplotlib >= 3.6 this style is named 'seaborn-v0_8-whitegrid'
import numpy as np
Scatter Plots with plt.plot
使用 plt.plot 绘制散点图
In the previous section we looked at plt.plot / ax.plot to produce line plots. It turns out that this same function can
produce scatter plots as well:
在上一节中,我们介绍了使用 plt.plot / ax.plot 方法绘制折线图。事实上,同样的函数也可以用来绘制散点图:
In [3]: x = np.linspace(0, 10, 30)
y = np.sin(x)
plt.plot(x, y, 'o', color='black');
The third argument in the function call is a character that represents the type of symbol used for the plotting. Just as you
can specify options such as '-' , '--' to control the line style, the marker style has its own set of short string codes.
The full list of available symbols can be seen in the documentation of plt.plot , or in Matplotlib's online
documentation. Most of the possibilities are fairly intuitive, and we'll show a number of the more common ones here:
传递给函数的第三个参数是一个代表绘图所用符号类型的字符。就像你可以使用 '-' 或 '--' 来控制线条的风格那样,点的风格也有自己的
一套短字符串代码。所有可用的符号可以通过 plt.plot 文档或Matplotlib在线文档进行查阅。大多数的代码都非常直观,我们使用下面的
例子展示其中一些最常用的符号:
In [4]: rng = np.random.RandomState(0)
for marker in ['o', '.', ',', 'x', '+', 'v', '^', '<', '>', 's', 'd']:
    plt.plot(rng.rand(5), rng.rand(5), marker,
             label="marker='{0}'".format(marker))
plt.legend(numpoints=1)
plt.xlim(0, 1.8);
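What makes plt.plot draw a scatter plot is that a format string containing only a marker code leaves the line style empty ('None'), so no connecting line is drawn. The sketch below (Agg backend forced so it runs headless) inspects the resulting Line2D objects for a marker-only format and for the combined '-ok' format used next.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 30)
fig, ax = plt.subplots()
(points,) = ax.plot(x, np.sin(x), "o", color="black")  # marker only
(both,) = ax.plot(x, np.cos(x), "-ok")                  # line + marker + color

print(points.get_linestyle(), points.get_marker())  # None o
print(both.get_linestyle(), both.get_marker(), both.get_color())  # - o k
```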
For even more possibilities, these character codes can be used together with line and color codes to plot points along
with a line connecting them:
而且这些符号代码还可以和线条、颜色代码一起使用,在绘制散点的同时用线条将它们连接起来:
In [5]: plt.plot(x, y, '-ok');
Additional keyword arguments to plt.plot specify a wide range of properties of the lines and markers:
plt.plot 还有很多额外的关键字参数,可以用来指定线条和点的各种属性:
In [6]: plt.plot(x, y, '-p', color='gray',
markersize=15, linewidth=4,
markerfacecolor='white',
markeredgecolor='gray',
markeredgewidth=2)
plt.ylim(-1.2, 1.2);
This type of flexibility in the plt.plot function allows for a wide variety of possible visualization options. For a full
description of the options available, refer to the plt.plot documentation.
plt.plot 函数的这种灵活性提供了非常多的可视化选择。查阅 plt.plot 帮助文档获得完整的选项说明。
Scatter Plots with plt.scatter
使用 plt.scatter 绘制散点图
A second, more powerful method of creating scatter plots is the plt.scatter function, which can be used very
similarly to the plt.plot function:
第二种更强大的绘制散点图的方法是使用 plt.scatter 函数,它的使用方法和 plt.plot 非常类似:
In [7]: plt.scatter(x, y, marker='o');
The primary difference of plt.scatter from plt.plot is that it can be used to create scatter plots where the
properties of each individual point (size, face color, edge color, etc.) can be individually controlled or mapped to data.
plt.scatter 和 plt.plot 的主要区别在于, plt.scatter 可以单独控制每个点的属性(大小、填充颜色、边缘颜色等),还可以将这些
属性映射到数据上。
Let's show this by creating a random scatter plot with points of many colors and sizes. In order to better see the
overlapping results, we'll also use the alpha keyword to adjust the transparency level:
让我们通过绘制一个具有多种颜色和大小的随机散点图来说明这一点。为了更好地查看重叠的结果,我们还使用了 alpha 关键字参数对点
的透明度进行了调整:
In [8]: rng = np.random.RandomState(0)
x = rng.randn(100)
y = rng.randn(100)
colors = rng.rand(100)
sizes = 1000 * rng.rand(100)
plt.scatter(x, y, c=colors, s=sizes, alpha=0.3,
cmap='viridis')
plt.colorbar();  # show color scale
Notice that the color argument is automatically mapped to a color scale (shown here by the colorbar() command),
and that the size argument is given in points squared, i.e. it sets the area of each marker. In this way, the color and
size of points can be used to convey information in the visualization, in order to visualize multidimensional data.
注意颜色参数会被自动映射到一个颜色标尺上(这里通过 colorbar() 函数显示),而大小参数的单位是磅的平方,即每个标记的面积。
使用这种方法,散点的颜色和大小都能用来展示数据信息,从而实现多维数据的可视化。
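The per-point data stay attached to the object that plt.scatter returns (a PathCollection), which makes the mapping easy to inspect. The sketch below reuses the random data from the example with the object-oriented calls ax.scatter and fig.colorbar, and the Agg backend is forced only so it runs headless; the color limits are taken from the range of the color values.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(100)
y = rng.randn(100)
colors = rng.rand(100)
sizes = 1000 * rng.rand(100)

fig, ax = plt.subplots()
sc = ax.scatter(x, y, c=colors, s=sizes, alpha=0.3, cmap="viridis")
cbar = fig.colorbar(sc, ax=ax)

# The returned PathCollection keeps the per-point data:
offsets = sc.get_offsets()  # (x, y) position of each point
values = sc.get_array()     # data mapped through the colormap
areas = sc.get_sizes()      # marker areas in points**2
vmin, vmax = sc.get_clim()  # color scale limits, taken from the data
print(offsets.shape)  # (100, 2)
```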
For example, we might use the Iris data from Scikit-Learn, where each sample is one of three types of flowers that has
had the size of its petals and sepals carefully measured:
例如,我们可以使用Scikit-Learn中的鸢尾花数据集,里面的每个样本都是三种鸢尾花中的一种,并带有仔细测量的花瓣和花萼的尺寸数
据:
In [9]: from sklearn.datasets import load_iris
iris = load_iris()
features = iris.data.T
plt.scatter(features[0], features[1], alpha=0.2,
s=100*features[3], c=iris.target, cmap='viridis')
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1]);
We can see that this scatter plot has given us the ability to simultaneously explore four different dimensions of the data:
the (x, y) location of each point corresponds to the sepal length and width, the size of the point is related to the petal
width, and the color is related to the particular species of flower. Multicolor and multifeature scatter plots like this can be
useful for both exploration and presentation of data.
我们可以从上图中看出,散点图让我们能够同时探索该数据集的四个不同维度:图中的(x, y)位置代表每个样本的花萼长度和宽度,散点的
大小代表花瓣的宽度,而散点的颜色代表鸢尾花的种类。像这样的多颜色、多特征散点图对于数据的探索和展示都非常有帮助。
plot Versus scatter : A Note on Efficiency
plot 和 scatter 对比:性能提醒
Aside from the different features available in plt.plot and plt.scatter , why might you choose to use one over
the other? While it doesn't matter as much for small amounts of data, as datasets get larger than a few thousand points,
plt.plot can be noticeably more efficient than plt.scatter . The reason is that plt.scatter has the
capability to render a different size and/or color for each point, so the renderer must do the extra work of constructing
each point individually. In plt.plot , on the other hand, the points are always essentially clones of each other, so the
work of determining the appearance of the points is done only once for the entire set of data. For large datasets, the
difference between these two can lead to vastly different performance, and for this reason, plt.plot should be
preferred over plt.scatter for large datasets.
除了上面说的 plt.plot 和 plt.scatter 在功能上的不同之外,还有什么原因会让你选择其中一个而不是另一个吗?对于小数据集来说,两者
差别不大,但当数据集增长到几千个点以上时, plt.plot 的效率会明显高于 plt.scatter 。造成这个差异的原因是 plt.scatter 支持为
每个点使用不同的大小和/或颜色,因此渲染器必须额外地逐个构建每个点。而对于 plt.plot 来说,所有的点本质上都是彼此的复制,因此
确定点外观的工作对整个数据集只需进行一次。对于很大的数据集来说,这个差异会导致两者性能的巨大区别,因此,对于大数据集应该优先
使用 plt.plot 函数。
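To see the difference on your own machine, a rough timing sketch like the following can help. It renders the same random points with ax.plot(..., 'o') and with ax.scatter(...) on the non-interactive Agg backend. The point count (20,000), the best-of-three timing, and the helper name time_render are arbitrary choices for illustration; the measured gap depends on your Matplotlib version and backend, so treat the numbers as indicative only.

```python
import time
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: timings include Agg rendering
import matplotlib.pyplot as plt
import numpy as np


def time_render(plot_func, n=20_000, repeats=3):
    """Return the best-of-`repeats` time in seconds to create and draw n points."""
    rng = np.random.default_rng(0)
    x, y = rng.random(n), rng.random(n)
    best = float("inf")
    for _ in range(repeats):
        fig, ax = plt.subplots()
        start = time.perf_counter()
        plot_func(ax, x, y)
        fig.canvas.draw()  # force the actual rendering
        best = min(best, time.perf_counter() - start)
        plt.close(fig)
    return best


t_plot = time_render(lambda ax, x, y: ax.plot(x, y, "o"))
t_scatter = time_render(lambda ax, x, y: ax.scatter(x, y))
print(f"plot: {t_plot:.3f}s  scatter: {t_scatter:.3f}s")
```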
Visualizing Errors
误差可视化
For any scientific measurement, accurate accounting for errors is nearly as important as, if not more important than,
accurate reporting of the number itself. For example, imagine that I am using some astrophysical observations to
estimate the Hubble Constant, the local measurement of the expansion rate of the Universe. I know that the current
literature suggests a value of around 71 (km/s)/Mpc, and I measure a value of 74 (km/s)/Mpc with my method. Are the
values consistent? The only correct answer, given this information, is this: there is no way to know.
对于任何科学测量来说,精确计算误差与精确报告测量值本身几乎同等重要,甚至更加重要。例如,设想我正在使用一些天体物理学观测值
来估算哈勃常数,即对宇宙膨胀速率的局域测量。我从最新的文献中知道这个值大约是71 (km/s)/Mpc,而我用自己的方法测量得到的值是
74 (km/s)/Mpc。这两个值是否一致?在仅给定这些信息的情况下,唯一正确的答案是:无法判断。
Translator's note: Mpc (megaparsec) is one million parsecs; see the entry on the parsec.
Suppose I augment this information with reported uncertainties: the current literature suggests a value of around 71 ±
2.5 (km/s)/Mpc, and my method has measured a value of 74 ± 5 (km/s)/Mpc. Now are the values consistent? That is a
question that can be quantitatively answered.
如果我们将信息增加一些,给出不确定性:最新的文献表示哈勃常数的值大约是71 ± 2.5 (km/s)/Mpc,我的测量值是74 ± 5 (km/s)/Mpc。
这两个值是一致的吗?这就是一个可以定量回答的问题了。
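The quantitative answer can be sketched as follows, under the assumption (not stated in the text) that both uncertainties are independent Gaussian 1σ errors: divide the difference of the two values by their combined standard deviation. This standard-score check is chosen for illustration and is not necessarily the procedure the author has in mind; the function name tension_in_sigma is hypothetical.

```python
import math


def tension_in_sigma(a, sigma_a, b, sigma_b):
    """Difference of two independent Gaussian measurements, expressed in
    units of their combined standard deviation sqrt(sigma_a**2 + sigma_b**2)."""
    return abs(a - b) / math.sqrt(sigma_a**2 + sigma_b**2)


# literature value 71 +/- 2.5, my measurement 74 +/- 5, both in (km/s)/Mpc
z = tension_in_sigma(71, 2.5, 74, 5)
print(round(z, 2))  # 0.54
```

A difference of about half a combined standard deviation is well within the scatter expected from the stated errors, so under these assumptions the two values are consistent.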
In visualization of data and results, showing these errors effectively can make a plot convey much more complete
information.
在数据和结果的可视化中,有效地展示这些误差能使图表传达更加完整的信息。
Basic Errorbars
基础误差条
A basic errorbar can be created with a single Matplotlib function call:
调用一个Matplotlib函数就能创建一个基础的误差条:
In [1]: %matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('seaborn-whitegrid')  # in Matplotlib >= 3.6 this style is named 'seaborn-v0_8-whitegrid'
import numpy as np
In [2]: x = np.linspace(0, 10, 50)
dy = 0.8
y = np.sin(x) + dy * np.random.randn(50)
plt.errorbar(x, y, yerr=dy, fmt='.k');
Here the fmt is a format code controlling the appearance of lines and points, and has the same syntax as the
shorthand used in plt.plot , outlined in Simple Line Plots and Simple Scatter Plots.
这里的 fmt 参数是用来控制线条和点外观的格式代码,与 plt.plot 中使用的简写语法相同,参见简单的折线图和简单的散点图。
In addition to these basic options, the errorbar function has many options to fine-tune the outputs. Using these
additional options you can easily customize the aesthetics of your errorbar plot. I often find it helpful, especially in
crowded plots, to make the errorbars lighter than the points themselves:
除了上面的基本参数, errorbar 函数还有很多参数可以用来精细调节图表输出。使用这些参数你可以很容易地个性化调整误差条的样
式。作者发现,特别是在图表比较拥挤的情况下,将误差条的颜色调得比数据点更浅通常会更清晰:
In [3]: plt.errorbar(x, y, yerr=dy, fmt='o', color='black',
ecolor='lightgray', elinewidth=3, capsize=0);
In addition to these options, you can also specify horizontal errorbars ( xerr ), one-sided errorbars, and many other
variants. For more information on the options available, refer to the docstring of plt.errorbar .
除了上面介绍的参数,你还可以指定水平方向的误差条( xerr )、单边误差条和其他很多变体。参阅 plt.errorbar 的帮助文档获
得更多可用选项的信息。
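Horizontal and one-sided errorbars can be sketched as follows: xerr takes the horizontal error, and yerr may be a 2×N array whose first row holds the lower errors and whose second row holds the upper errors, so a row of zeros gives one-sided bars. The data, error sizes, and capsize here are arbitrary, and the Agg backend is forced only so it runs headless. The returned container unpacks into the data line, the cap lines, and the bar collections.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 10)
y = np.sin(x)
fig, ax = plt.subplots()

# Horizontal errors of +/-0.3 on every point, plus one-sided vertical errors:
# first row = lower error (none), second row = upper error (0.5).
yerr = np.vstack([np.zeros_like(y), np.full_like(y, 0.5)])
container = ax.errorbar(x, y, xerr=0.3, yerr=yerr, fmt="o", capsize=3)

data_line, caplines, barlinecols = container
print(container.has_xerr, container.has_yerr)  # True True
```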
Continuous Errors
连续误差
In some situations it is desirable to show errorbars on continuous quantities. Though Matplotlib does not have a built-in
convenience routine for this type of application, it's relatively easy to combine primitives like plt.plot and
plt.fill_between for a useful result.
在某些情况下可能需要对连续量展示误差条。虽然Matplotlib没有内建的便捷函数能直接完成这个任务,但是你可以通过将 plt.plot 和
plt.fill_between 这样的基础函数结合起来,相对容易地达到目标。
Here we'll perform a simple Gaussian process regression, using the Scikit-Learn API (see Introducing Scikit-Learn for
details). This is a method of fitting a very flexible non-parametric function to data with a continuous measure of the
uncertainty. We won't delve into the details of Gaussian process regression at this point, but will focus instead on how
you might visualize such a continuous error measurement:
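这里我们会使用Scikit-Learn的API(参见Scikit-Learn介绍)进行简单的高斯过程回归。这是一种用非常灵活的非参数函数拟合数据,同时
给出连续的不确定性度量的方法。我们在这里不会详细介绍高斯过程回归,而是聚焦在如何可视化这种连续误差上:
Translator's note: newer versions of scikit-learn changed the Gaussian process regression implementation, and the code below has been updated accordingly.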
In [12]: from sklearn.gaussian_process import GaussianProcessRegressor
# define the model and draw some data
model = lambda x: x * np.sin(x)
xdata = np.array([1, 3, 5, 6, 8])
ydata = model(xdata)
# compute the Gaussian process fit
gp = GaussianProcessRegressor()
gp.fit(xdata[:, np.newaxis], ydata)
xfit = np.linspace(0, 10, 1000)
yfit, std = gp.predict(xfit[:, np.newaxis], return_std=True)
dyfit = 2 * std  # 2*sigma ~ 95% confidence region
We now have xfit , yfit , and dyfit , which sample the continuous fit to our data. We could pass these to the
plt.errorbar function as above, but we don't really want to plot 1,000 points with 1,000 errorbars. Instead, we can
use the plt.fill_between function with a light color to visualize this continuous error:
我们现在有了 xfit 、 yfit 和 dyfit ,它们是对数据连续拟合的采样值。我们当然也可以像上面一样把它们传给 plt.errorbar 函数,但
是我们并不希望在图表上绘制1000个点和1000个误差条。于是我们可以使用 plt.fill_between 函数,用浅色填充误差带来展示这种连续
误差:
In [13]: # Visualize the result
plt.plot(xdata, ydata, 'or')
plt.plot(xfit, yfit, '-', color='gray')
plt.fill_between(xfit, yfit - dyfit, yfit + dyfit,
color='gray', alpha=0.2)
plt.xlim(0, 10);
Note what we've done here with the fill_between function: we pass an x value, then the lower y-bound, then the
upper y-bound, and the result is that the area between these regions is filled.
注意上面我们调用 fill_between 函数的方式:我们依次传递了x值、y值的下限和y值的上限,结果是图表中介于下限和上限之间的区域会被
填充。
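The same (x, lower, upper) pattern works without any fitting library. In this minimal sketch the band half-width dy is a hypothetical uncertainty that simply grows with x, chosen only to illustrate the call, and the Agg backend is forced so it runs headless; fill_between returns a PolyCollection that is added to the axes.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PolyCollection

x = np.linspace(0, 10, 200)
y = np.sin(x)
dy = 0.2 + 0.05 * x  # hypothetical uncertainty that grows with x

fig, ax = plt.subplots()
ax.plot(x, y, "-", color="gray")
# arguments: x values, lower y-bound, upper y-bound
band = ax.fill_between(x, y - dy, y + dy, color="gray", alpha=0.2)
ax.set_xlim(0, 10)
```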
The resulting figure gives a very intuitive view into what the Gaussian process regression algorithm is doing: in regions
near a measured data point, the model is strongly constrained and this is reflected in the small model errors. In regions
far from a measured data point, the model is not strongly constrained, and the model errors increase.
上图为我们理解高斯过程回归算法在做什么提供了非常直观的视角:在观测数据点附近,模型受到很强的约束,这体现在较小的模型误差上;
在远离观测数据点的区域,模型的约束较弱,模型误差随之增大。
For more information on the options available in plt.fill_between() (and the closely related plt.fill()
function), see the function docstring or the Matplotlib documentation.
如果需要获得 plt.fill_between() (以及密切相关的 plt.fill() 函数)更多参数的信息,请查阅函数的帮助文档或Matplotlib在线文档。
Finally, if this seems a bit too low level for your taste, refer to Visualization With Seaborn, where we discuss the Seaborn
package, which has a more streamlined API for visualizing this type of continuous errorbar.
最后,如果你觉得这种做法过于底层、不合你的口味,请参考使用Seaborn进行可视化,该小节会讨论Seaborn包,它提供了更精简的API来
可视化这种类型的连续误差条。
Density and Contour Plots
密度和轮廓图
Sometimes it is useful to display three-dimensional data in two dimensions using contours or color-coded regions. There
are three Matplotlib functions that can be helpful for this task: plt.contour for contour plots, plt.contourf for
filled contour plots, and plt.imshow for showing images. This section looks at several examples of using these. We'll
start by setting up the notebook for plotting and importing the functions we will use:
In [1]: %matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('seaborn-v0_8-white')  # this style was named 'seaborn-white' before Matplotlib 3.6
import numpy as np
Visualizing a Three-Dimensional Function
We'll start by demonstrating a contour plot using a function z = f(x, y), using the following particular choice for f (we've
seen this before in Computation on Arrays: Broadcasting, when we used it as a motivating example for array
broadcasting):
In [2]: def f(x, y):
    return np.sin(x) ** 10 + np.cos(10 + y * x) * np.cos(x)
A contour plot can be created with the plt.contour function. It takes three arguments: a grid of x values, a grid of y
values, and a grid of z values. The x and y values represent positions on the plot, and the z values will be represented by
the contour levels. Perhaps the most straightforward way to prepare such data is to use the np.meshgrid function,
which builds two-dimensional grids from one-dimensional arrays:
In [3]: x = np.linspace(0, 5, 50)
y = np.linspace(0, 5, 40)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)
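The following tiny sketch shows what np.meshgrid actually returns; the toy values are chosen only for illustration. The 1D x array is repeated down the rows and the 1D y array across the columns, so each output has shape (len(y), len(x)). For In [3] above, this means X, Y and Z all have shape (40, 50).

```python
import numpy as np

x = np.array([0.0, 1.0, 2.0])   # 3 x values -> 3 columns
y = np.array([10.0, 20.0])      # 2 y values -> 2 rows
X, Y = np.meshgrid(x, y)

print(X)
# [[0. 1. 2.]
#  [0. 1. 2.]]
print(Y)
# [[10. 10. 10.]
#  [20. 20. 20.]]
```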
Now let's look at this with a standard line-only contour plot:
In [4]: plt.contour(X, Y, Z, colors='black');
Notice that by default when a single color is used, negative values are represented by dashed lines, and positive values
by solid lines. Alternatively, the lines can be color-coded by specifying a colormap with the cmap argument. Here, we'll
also specify that we want more lines to be drawn—20 equally spaced intervals within the data range:
In [5]: plt.contour(X, Y, Z, 20, cmap='RdGy');
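Strictly speaking, the integer 20 is a target rather than an exact count. Matplotlib chooses no more than 21 "nice" level values spanning the data. If you want exact control, the sketch below passes an explicit levels array instead; the level values are illustrative only. The dashed style of negative levels in single-color plots comes from the contour.negative_linestyle rcParam. Here it is set to 'solid' temporarily through plt.rc_context.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt
import numpy as np

def f(x, y):
    return np.sin(x) ** 10 + np.cos(10 + y * x) * np.cos(x)

x = np.linspace(0, 5, 50)
y = np.linspace(0, 5, 40)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)

levels = np.linspace(-0.5, 0.5, 11)  # illustrative: exactly these 11 contour values
fig, ax = plt.subplots()
# Draw negative levels with solid lines instead of the default dashes
with plt.rc_context({'contour.negative_linestyle': 'solid'}):
    cs = ax.contour(X, Y, Z, levels=levels, colors='black')
```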
Here we chose the RdGy (short for Red-Gray) colormap, which is a good choice for centered data. Matplotlib has a wide
range of colormaps available, which you can easily browse in IPython by doing a tab completion on the plt.cm
module:
plt.cm.<TAB>
Our plot is looking nicer, but the spaces between the lines may be a bit distracting. We can change this by switching to a
filled contour plot using the plt.contourf() function (notice the f at the end), which uses largely the same syntax
as plt.contour() .
Additionally, we'll add a plt.colorbar() command, which automatically creates an additional axis with labeled color
information for the plot:
In [6]: plt.contourf(X, Y, Z, 20, cmap='RdGy')
plt.colorbar();
The colorbar makes it clear that the black regions are "peaks," while the red regions are "valleys."
One potential issue with this plot is that it is a bit "splotchy." That is, the color steps are discrete rather than continuous,
which is not always what is desired. This could be remedied by setting the number of contours to a very high number, but
this results in a rather inefficient plot: Matplotlib must render a new polygon for each step in the level. A better way to
handle this is to use the plt.imshow() function, which interprets a two-dimensional grid of data as an image.
The following code shows this:
In [7]: plt.imshow(Z, extent=[0, 5, 0, 5], origin='lower',
           cmap='RdGy')
plt.colorbar()
plt.axis('image');
There are a few potential gotchas with imshow() , however:
plt.imshow() doesn't accept an x and y grid, so you must manually specify the extent [xmin, xmax, ymin, ymax]
of the image on the plot.
plt.imshow() by default follows the standard image array definition where the origin is in the upper left, not in the
lower left as in most contour plots. This must be changed when showing gridded data.
plt.imshow() will automatically adjust the axis aspect ratio to match the input data; this can be changed by
setting, for example, plt.axis('image') to make x and y units match.
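The first two gotchas are easiest to see side by side. This sketch uses toy 3×4 data and the Agg backend. The default origin='upper' puts row 0 at the top and inverts the y-axis. Passing origin='lower' plus an explicit extent places the same array in data coordinates. ax.axis('image') then makes the x and y units equal.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt
import numpy as np

Z = np.arange(12).reshape(3, 4)   # toy data: 3 rows, 4 columns

fig, (ax0, ax1) = plt.subplots(1, 2)
# Image convention: row 0 drawn at the top, y-axis runs downward
im0 = ax0.imshow(Z)
# Grid convention: row 0 at the bottom, axes in data units
im1 = ax1.imshow(Z, origin='lower', extent=[0, 5, 0, 5])
ax1.axis('image')                 # equal units, limits fitted to the image

print(ax0.get_ylim(), ax1.get_ylim())
```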
Finally, it can sometimes be useful to combine contour plots and image plots. For example, here we'll use a partially
transparent background image (with transparency set via the alpha parameter) and overplot contours with labels on
the contours themselves (using the plt.clabel() function):
In [8]: contours = plt.contour(X, Y, Z, 3, colors='black')
plt.clabel(contours, inline=True, fontsize=8)
plt.imshow(Z, extent=[0, 5, 0, 5], origin='lower',
           cmap='RdGy', alpha=0.5)
plt.colorbar();
The combination of these three functions— plt.contour , plt.contourf , and plt.imshow —gives nearly
limitless possibilities for displaying this sort of three-dimensional data within a two-dimensional plot. For more information
on the options available in these functions, refer to their docstrings. If you are interested in three-dimensional
visualizations of this type of data, see Three-dimensional Plotting in Matplotlib.
Histograms, Binnings, and Density
A simple histogram can be a great first step in understanding a dataset. Earlier, we saw a preview of Matplotlib's
histogram function (see Comparisons, Masks, and Boolean Logic), which creates a basic histogram in one line, once the
normal boilerplate imports are done:
In [1]: %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('seaborn-v0_8-white')  # this style was named 'seaborn-white' before Matplotlib 3.6
data = np.random.randn(1000)
In [2]: plt.hist(data);
The hist() function has many options to tune both the calculation and the display; here's an example of a more
customized histogram:
Translator's note: the normed argument is obsolete, so the code below uses its replacement, density.
In [3]: plt.hist(data, bins=30, density=True, alpha=0.5,
         histtype='stepfilled', color='steelblue',
         edgecolor='none');
The plt.hist docstring has more information on other customization options available. I find this combination of
histtype='stepfilled' along with some transparency alpha to be very useful when comparing histograms of
several distributions:
In [4]: x1 = np.random.normal(0, 0.8, 1000)
x2 = np.random.normal(-2, 1, 1000)
x3 = np.random.normal(3, 2, 1000)
kwargs = dict(histtype='stepfilled', alpha=0.3, density=True, bins=40)
plt.hist(x1, **kwargs)
plt.hist(x2, **kwargs)
plt.hist(x3, **kwargs);
If you would like to simply compute the histogram (that is, count the number of points in a given bin) and not display it, the
np.histogram() function is available:
In [5]: counts, bin_edges = np.histogram(data, bins=5)
print(counts)
[ 23 192 446 295  44]
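Here is a short sketch of what np.histogram returns. It uses a seeded generator in place of np.random.randn (an assumption, made so the numbers are reproducible). With bins=5 you get five counts and six edges. density=True, the same option passed to plt.hist above, rescales the counts so the total bar area is 1.

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.standard_normal(1000)

# Raw counts: 5 bins are delimited by 6 edges
counts, bin_edges = np.histogram(data, bins=5)

# density=True divides each count by (number of points * bin width)
density, _ = np.histogram(data, bins=bin_edges, density=True)
area = np.sum(density * np.diff(bin_edges))

print(counts.sum(), len(bin_edges))  # 1000 6
print(area)                          # approximately 1.0
```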
Two-Dimensional Histograms and Binnings
Just as we create histograms in one dimension by dividing the number line into bins, we can also create histograms in
two dimensions by dividing points among two-dimensional bins. We'll take a brief look at several ways to do this here.
We'll start by defining some data, x and y arrays drawn from a multivariate Gaussian distribution:
In [6]: mean = [0, 0]
cov = [[1, 1], [1, 2]]
x, y = np.random.multivariate_normal(mean, cov, 10000).T
plt.hist2d : Two-dimensional histogram
One straightforward way to plot a two-dimensional histogram is to use Matplotlib's plt.hist2d function:
In [7]: plt.hist2d(x, y, bins=30, cmap='Blues')
cb = plt.colorbar()
cb.set_label('counts in bin')
Just as with plt.hist , plt.hist2d has a number of extra options to fine-tune the plot and the binning, which are
nicely outlined in the function docstring. Further, just as plt.hist has a counterpart in np.histogram ,
plt.hist2d has a counterpart in np.histogram2d , which can be used as follows:
In [8]: counts, xedges, yedges = np.histogram2d(x, y, bins=30)
For the generalization of this histogram binning in dimensions higher than two, see the np.histogramdd function.
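In the sketch below, a seeded generator is assumed, with the same mean and covariance as In [6]. The counts array is indexed as counts[i, j], with x along the first axis and y along the second. That is transposed relative to how imshow draws rows, so plot counts.T with origin='lower'. np.histogramdd takes an (N, D) array of samples and one bin count per dimension. The third variable z is made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)
mean = [0, 0]
cov = [[1, 1], [1, 2]]
x, y = rng.multivariate_normal(mean, cov, 10000).T

# 2D: counts[i, j] holds the points with x in bin i and y in bin j
counts, xedges, yedges = np.histogram2d(x, y, bins=30)

# N-D generalization: an (N, D) array of samples, one bin count per axis
z = rng.standard_normal(10000)   # hypothetical third variable
counts3, edges3 = np.histogramdd(np.column_stack([x, y, z]), bins=(10, 10, 5))
```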
plt.hexbin : Hexagonal binnings
The two-dimensional histogram creates a tessellation of squares across the axes. Another natural shape for such a
tessellation is the regular hexagon. For this purpose, Matplotlib provides the plt.hexbin routine, which represents
a two-dimensional dataset binned within a grid of hexagons:
In [9]: plt.hexbin(x, y, gridsize=30, cmap='Blues')
cb = plt.colorbar(label='count in bin')
plt.hexbin has a number of interesting options, including the ability to specify weights for each point, and to change
the output in each bin to any NumPy aggregate (mean of weights, standard deviation of weights, etc.).
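Here is a sketch of those options, assuming made-up per-point weights drawn uniformly from [0, 1]. When C is given, each hexagon shows reduce_C_function (np.mean by default) applied to the C values of the points that fall inside it. Here np.std is used to show the per-bin standard deviation of the weights.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
x, y = rng.multivariate_normal([0, 0], [[1, 1], [1, 2]], 10000).T
weights = rng.uniform(0, 1, size=x.size)   # hypothetical per-point values

fig, ax = plt.subplots()
# Each hexagon reports the standard deviation of the weights inside it
hb = ax.hexbin(x, y, C=weights, reduce_C_function=np.std,
               gridsize=30, cmap='Blues')
fig.colorbar(hb, ax=ax, label='std of weights in bin')
vals = hb.get_array()   # one reduced value per displayed hexagon
```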
Kernel density estimation
Another common method of evaluating densities in multiple dimensions is kernel density estimation (KDE). This will be
discussed more fully in In-Depth: Kernel Density Estimation, but for now we'll simply mention that KDE can be thought of
as a way to "smear out" the points in space and add up the result to obtain a smooth function. One extremely quick and
simple KDE implementation exists in the scipy.stats package. Here is a quick example of using the KDE on this
data:
In [10]: from scipy.stats import gaussian_kde
# stack the data into a (2, N) array and initialize the KDE
data = np.vstack([x, y])
kde = gaussian_kde(data)
# evaluate the density Z on a regular grid
xgrid = np.linspace(-3.5, 3.5, 40)
ygrid = np.linspace(-6, 6, 40)
Xgrid, Ygrid = np.meshgrid(xgrid, ygrid)
Z = kde.evaluate(np.vstack([Xgrid.ravel(), Ygrid.ravel()]))
# plot the result as an image
plt.imshow(Z.reshape(Xgrid.shape),
           origin='lower', aspect='auto',
           extent=[-3.5, 3.5, -6, 6],
           cmap='Blues')
cb = plt.colorbar()
cb.set_label("density")
KDE has a smoothing length that effectively slides the knob between detail and smoothness (one example of the
ubiquitous bias–variance trade-off). The literature on choosing an appropriate smoothing length is vast: gaussian_kde
uses a rule-of-thumb to attempt to find a nearly optimal smoothing length for the input data.
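You can also set the smoothing length explicitly through gaussian_kde's bw_method argument. This sketch uses seeded data with the same covariance as above. The default factor follows Scott's rule, n**(-1/(d+4)) for n points in d dimensions. A scalar bw_method sets the factor directly; the value 0.05 is only an illustration of a much smaller smoothing length.

```python
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(2)
x, y = rng.multivariate_normal([0, 0], [[1, 1], [1, 2]], 10000).T
data = np.vstack([x, y])        # shape (d, n) = (2, 10000)
d, n = data.shape

kde_default = gaussian_kde(data)                  # Scott's rule of thumb
kde_narrow = gaussian_kde(data, bw_method=0.05)   # scalar: factor used as-is

print(kde_default.factor, n ** (-1.0 / (d + 4)))  # the two numbers agree
print(kde_narrow.factor)                          # 0.05

# A smaller factor follows the samples more closely (less smoothing)
origin = np.array([[0.0], [0.0]])
print(kde_default(origin), kde_narrow(origin))
```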
Other KDE implementations are available within the SciPy ecosystem, each with its own strengths and weaknesses; see,
for example, sklearn.neighbors.KernelDensity and
statsmodels.nonparametric.kernel_density.KDEMultivariate . For visualizations based on KDE, using
Matplotlib tends to be overly verbose. The Seaborn library, discussed in Visualization With Seaborn, provides a much
more terse API for creating KDE-based visualizations.
Customizing Plot Legends
Plot legends give meaning to a visualization by identifying what its various plot elements represent. We previously saw how to
create a simple legend; here we'll take a look at customizing the placement and aesthetics of the legend in Matplotlib.
The simplest legend can be created with the plt.legend() command, which automatically creates a legend for any
labeled plot elements:
In [1]: import matplotlib.pyplot as plt
plt.style.use('classic')
In [2]: %matplotlib inline
import numpy as np
In [3]: x = np.linspace(0, 10, 1000)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x), '-b', label='Sine')
ax.plot(x, np.cos(x), '--r', label='Cosine')
ax.axis('equal')
leg = ax.legend();
But there are many ways we might want to customize such a legend. For example, we can specify the location and turn
off the frame:
In [4]: ax.legend(loc='upper left', frameon=False)
fig
Out[4]:
We can use the ncol command to specify the number of columns in the legend:
In [5]: ax.legend(frameon=False, loc='lower center', ncol=2)
fig
Out[5]:
We can use a rounded box ( fancybox ) or add a shadow, change the transparency (alpha value) of the frame, or
change the padding around the text:
In [6]: ax.legend(fancybox=True, framealpha=1, shadow=True, borderpad=1)
fig
Out[6]:
For more information on available legend options, see the plt.legend docstring.
Choosing Elements for the Legend
As we have already seen, the legend includes all labeled elements by default. If this is not what is desired, we can
fine-tune which elements and labels appear in the legend by using the objects returned by plot commands. The
plt.plot() command is able to create multiple lines at once, and returns a list of created line instances. Passing any
of these to plt.legend() will tell it which to identify, along with the labels we'd like to specify:
In [7]: y = np.sin(x[:, np.newaxis] + np.pi * np.arange(0, 2, 0.5))
lines = plt.plot(x, y)
# lines is a list of plt.Line2D instances
plt.legend(lines[:2], ['first', 'second']);
I generally find in practice that it is clearer to use the first method, applying labels to the plot elements you'd like to show
on the legend:
In [8]: plt.plot(x, y[:, 0], label='first')
plt.plot(x, y[:, 1], label='second')
plt.plot(x, y[:, 2:])
plt.legend(framealpha=1, frameon=True);
Notice that by default, the legend ignores all elements without a label attribute set.
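Matplotlib also treats any label that begins with an underscore, such as '_nolegend_', as hidden. This is how the automatically assigned labels of unlabeled artists stay out of the legend. A minimal check:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x), label='first')
ax.plot(x, np.cos(x))                       # no label: left out
ax.plot(x, -np.sin(x), label='_nolegend_')  # leading underscore: left out
leg = ax.legend()
labels = [t.get_text() for t in leg.get_texts()]
print(labels)  # ['first']
```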
Legend for Size of Points
Sometimes the legend defaults are not sufficient for the given visualization. For example, perhaps you're using the
size of points to mark certain features of the data, and want to create a legend reflecting this. Here is an example where
we'll use the size of points to indicate the areas of California cities and their color to indicate (log-scaled) populations.
We'd like a legend that specifies the scale of the sizes of the points, and we'll accomplish this by plotting some labeled
data with no entries:
Translator's note: newer versions of Matplotlib no longer accept the aspect argument here, so the code calls axis with 'scaled' instead.
In [9]: import pandas as pd
cities = pd.read_csv('data/california_cities.csv')
# Extract the data we're interested in
lat, lon = cities['latd'], cities['longd']
population, area = cities['population_total'], cities['area_total_km2']
# Scatter the points, using size for area and color for population, but no label
plt.scatter(lon, lat, label=None,
            c=np.log10(population), cmap='viridis',
            s=area, linewidth=0, alpha=0.5)
plt.axis('scaled')
plt.xlabel('longitude')
plt.ylabel('latitude')
plt.colorbar(label='log$_{10}$(population)')
plt.clim(3, 7)
# Here we create a legend:
# plot empty lists with the desired size and label, with transparency
for area in [100, 300, 500]:
    plt.scatter([], [], c='k', alpha=0.3, s=area,
                label=str(area) + ' km$^2$')
plt.legend(scatterpoints=1, frameon=False, labelspacing=1, title='City Area')
plt.title('California Cities: Area and Population');
The legend will always reference some object that is on the plot, so if we'd like to display a particular shape we need to
plot it. In this case, the objects we want (gray circles) are not on the plot, so we fake them by plotting empty lists. Notice
too that the legend only lists plot elements that have a label specified.
By plotting empty lists, we create labeled plot objects which are picked up by the legend, and now our legend tells us
some useful information. This strategy can be useful for creating more sophisticated visualizations.
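The same size legend can also be built without drawing empty scatter plots, by passing proxy artists to legend(handles=...). This sketch shows that alternative technique, which is not the one used in In [9]. It assumes the proxy symbols should match scatter's sizes. A Line2D marker's markersize is in points, while scatter's s is in points squared, hence the square root. The areas 100, 300 and 500 are the same illustrative values as above.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

fig, ax = plt.subplots()
# Proxy handles: markers only (no line), never added to the axes themselves
handles = [
    Line2D([], [], marker='o', linestyle='none', color='k', alpha=0.3,
           markersize=np.sqrt(area), label=f'{area} km$^2$')
    for area in [100, 300, 500]
]
leg = ax.legend(handles=handles, frameon=False, labelspacing=1,
                title='City Area')
labels = [t.get_text() for t in leg.get_texts()]
```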
Finally, note that for geographic data like this, it would be clearer if we could show state boundaries or other map-specific
elements. For this, an excellent choice of tool is Matplotlib's Basemap addon toolkit, which we'll explore in Geographic
Data with Basemap.
Multiple Legends
Sometimes when designing a plot you'd like to add multiple legends to the same axes. Unfortunately, Matplotlib does not
make this easy: via the standard legend interface, it is only possible to create a single legend for the entire plot. If you
try to create a second legend using plt.legend() or ax.legend() , it will simply override the first one. We can
work around this by creating a new legend artist from scratch, and then using the lower-level ax.add_artist()
method to manually add the second artist to the plot:
In [10]: fig, ax = plt.subplots()
lines = []
styles = ['-', '--', '-.', ':']
x = np.linspace(0, 10, 1000)
for i in range(4):
    lines += ax.plot(x, np.sin(x - i * np.pi / 2),
                     styles[i], color='black')
ax.axis('equal')
# specify the lines and labels of the first legend
ax.legend(lines[:2], ['line A', 'line B'],
          loc='upper right', frameon=False)
# Create the second legend and add the artist manually.
from matplotlib.legend import Legend
leg = Legend(ax, lines[2:], ['line C', 'line D'],
             loc='lower right', frameon=False)
ax.add_artist(leg);
This is a peek into the low-level artist objects that comprise any Matplotlib plot. If you examine the source code of
ax.legend() (recall that you can do this within the IPython notebook using ax.legend??), you'll see that the
function simply consists of some logic to create a suitable Legend artist, which is then saved in the legend_ attribute
and added to the figure when the plot is drawn.
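The override behavior described above can be checked directly. In this sketch, the second ax.legend() call replaces the first as the axes' legend, which is what ax.get_legend() returns. Re-attaching the first Legend with ax.add_artist() then keeps both on the plot. This is the same idea as In [10], without constructing a Legend by hand.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
a, = ax.plot(x, np.sin(x))
b, = ax.plot(x, np.cos(x))

first = ax.legend([a], ['line A'], loc='upper right')
second = ax.legend([b], ['line B'], loc='lower right')  # replaces `first`
replaced = ax.get_legend() is second

# Re-attach the first legend as an ordinary artist so both are drawn
ax.add_artist(first)
```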
Customizing Colorbars
Plot legends identify discrete labels of discrete points. For continuous labels based on the color of points, lines, or
regions, a labeled colorbar can be a great tool. In Matplotlib, a colorbar is a separate axes that can provide a key for the
meaning of colors in a plot. Because the book is printed in black-and-white, this section has an accompanying online
supplement where you can view the figures in full color (https://github.com/jakevdp/PythonDataScienceHandbook). We'll
start by setting up the notebook for plotting and importing the functions we will use:
In [1]: import matplotlib.pyplot as plt
plt.style.use('classic')
In [2]: %matplotlib inline
import numpy as np
As we have seen several times throughout this section, the simplest colorbar can be created with the plt.colorbar
function:
In [3]: x = np.linspace(0, 10, 1000)
I = np.sin(x) * np.cos(x[:, np.newaxis])
plt.imshow(I)
plt.colorbar();
We'll now discuss a few ideas for customizing these colorbars and using them effectively in various situations.
Customizing Colorbars
The colormap can be specified using the cmap argument to the plotting function that is creating the visualization:
In [4]: plt.imshow(I, cmap='gray');
All the available colormaps are in the plt.cm namespace; using IPython's tab completion will give you a full list of built-in possibilities:
plt.cm.<TAB>
But being able to choose a colormap is just the first step: more important is how to decide among the possibilities! The
choice turns out to be much more subtle than you might initially expect.
Choosing the Colormap
A full treatment of color choice within visualization is beyond the scope of this book, but for entertaining reading on this
subject and others, see the article "Ten Simple Rules for Better Figures". Matplotlib's online documentation also has an
interesting discussion of colormap choice.
Broadly, you should be aware of three different categories of colormaps:
Sequential colormaps: These are made up of one continuous sequence of colors (e.g., binary or viridis ).
Divergent colormaps: These usually contain two distinct colors, which show positive and negative deviations from a
mean (e.g., RdBu or PuOr ).
Qualitative colormaps: these mix colors with no particular sequence (e.g., rainbow or jet ).
The jet colormap, which was the default in Matplotlib prior to version 2.0, is an example of a qualitative colormap. Its
status as the default was quite unfortunate, because qualitative maps are often a poor choice for representing
quantitative data. Among the problems is the fact that qualitative maps usually do not display any uniform progression in
brightness as the scale increases.
We can see this by converting the jet colorbar into black and white:
In [5]: from matplotlib.colors import LinearSegmentedColormap
def grayscale_cmap(cmap):
    """Return a grayscale version of the given colormap"""
    cmap = plt.get_cmap(cmap)  # look up the colormap object by name
    colors = cmap(np.arange(cmap.N))  # RGBA array of shape (N, 4)
    # convert RGBA to perceived grayscale luminance
    # cf. http://alienryderflex.com/hsp.html
    RGB_weight = [0.299, 0.587, 0.114]  # weights of the R, G, B channels
    luminance = np.sqrt(np.dot(colors[:, :3] ** 2, RGB_weight))  # sqrt of weighted sum of squares
    colors[:, :3] = luminance[:, np.newaxis]  # replace RGB with the luminance
    return LinearSegmentedColormap.from_list(cmap.name + "_gray", colors, cmap.N)
def view_colormap(cmap):
    """Plot a colormap with its grayscale equivalent"""
    cmap = plt.get_cmap(cmap)
    colors = cmap(np.arange(cmap.N))
    cmap = grayscale_cmap(cmap)
    grayscale = cmap(np.arange(cmap.N))
    fig, ax = plt.subplots(2, figsize=(6, 2),
                           subplot_kw=dict(xticks=[], yticks=[]))
    ax[0].imshow([colors], extent=[0, 10, 0, 1])
    ax[1].imshow([grayscale], extent=[0, 10, 0, 1])
In [6]: view_colormap('jet')
Notice the bright stripes in the grayscale image. Even in full color, this uneven brightness means that the eye will be
drawn to certain portions of the color range, which will potentially emphasize unimportant parts of the dataset. It's better
to use a colormap such as viridis (the default as of Matplotlib 2.0), which is specifically constructed to have an even
brightness variation across the range. Thus it not only plays well with our color perception, but also will translate well to
grayscale printing:
In [7]: view_colormap('viridis')
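The brightness argument can also be checked numerically. This sketch reuses the luminance weighting from grayscale_cmap above. It finds that jet's brightest entry lies in the interior of the colormap, so brightness rises and then falls. viridis, by contrast, is brightest at its top end.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt
import numpy as np


def luminance(name):
    """Perceived brightness of each colormap entry (same weighting as grayscale_cmap)."""
    cmap = plt.get_cmap(name)
    rgb = cmap(np.arange(cmap.N))[:, :3]  # drop the alpha column
    return np.sqrt(np.dot(rgb ** 2, [0.299, 0.587, 0.114]))


jet = luminance('jet')
viridis = luminance('viridis')

# Where is each colormap brightest?
print(np.argmax(jet), np.argmax(viridis), len(viridis))
```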
If you favor rainbow schemes, another good option for continuous data is the cubehelix colormap:
In [8]: view_colormap('cubehelix')
For other situations, such as showing positive and negative deviations from some mean, dual-color colorbars such as
RdBu (Red-Blue) can be useful. However, as you can see in the following figure, it's important to note that the positive/negative information will be lost upon translation to grayscale!
In [9]: view_colormap('RdBu')
We'll see examples of using some of these color maps as we continue.
There are a large number of colormaps available in Matplotlib; to see a list of them, you can use IPython to explore the
plt.cm submodule. For a more principled approach to colors in Python, you can refer to the tools and documentation
within the Seaborn library (see Visualization With Seaborn).
Color limits and extensions
Matplotlib allows for a large range of colorbar customization. The colorbar itself is drawn on an ordinary plt.Axes
instance (available as the colorbar's ax attribute), so all of the axes and tick formatting tricks we've learned are
applicable. The colorbar has some interesting flexibility: for example, we can narrow the color limits and indicate the
out-of-bounds values with a triangular arrow at the top and bottom by setting the extend property. This might come in
handy, for example, if displaying an image that is subject to noise:
In [10]: # add noise to roughly 1% of the pixels of I
speckles = (np.random.random(I.shape) < 0.01)
I[speckles] = np.random.normal(0, 3, np.count_nonzero(speckles))
plt.figure(figsize=(10, 3.5))
# left: default color limits, stretched by the noise
plt.subplot(1, 2, 1)
plt.imshow(I, cmap='RdBu')
plt.colorbar()
# right: manual color limits that exclude the noise
plt.subplot(1, 2, 2)
plt.imshow(I, cmap='RdBu')
plt.colorbar(extend='both')
plt.clim(-1, 1);
Notice that in the left panel, the default color limits respond to the noisy pixels, and the range of the noise completely
washes out the pattern we are interested in. In the right panel, we manually set the color limits, and add extensions to
indicate values which are above or below those limits. The result is a much more useful visualization of our data.
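The same right-hand panel can be sketched with the object-oriented interface, which makes the pieces explicit. set_clim on the image plays the role of plt.clim. fig.colorbar(..., extend='both') adds the triangular out-of-range markers, and cb.ax is the ordinary Axes that holds the colorbar. A seeded generator is assumed for the noise.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(3)
x = np.linspace(0, 10, 1000)
I = np.sin(x) * np.cos(x[:, np.newaxis])
speckles = rng.random(I.shape) < 0.01          # roughly 1% noisy pixels
I[speckles] = rng.normal(0, 3, np.count_nonzero(speckles))

fig, ax = plt.subplots()
im = ax.imshow(I, cmap='RdBu')
im.set_clim(-1, 1)                             # same effect as plt.clim(-1, 1)
cb = fig.colorbar(im, ax=ax, extend='both')    # arrows mark clipped values
cb.set_label('value (clipped to [-1, 1])')     # cb.ax is an ordinary Axes
```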
Discrete Color Bars
Colormaps are by default continuous, but sometimes you'd like to represent discrete values. The easiest way to do this is
to use the plt.get_cmap() function (older code often spells this plt.cm.get_cmap(), which newer Matplotlib releases
have removed), and pass the name of a suitable colormap along with the number of desired bins:
In [11]: plt.imshow(I, cmap=plt.get_cmap('Blues', 6))
plt.colorbar()
plt.clim(-1, 1);
The discrete version of a colormap can be used just like any other colormap.
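Here is a quick sketch of what the second argument does. plt.get_cmap('Blues', 6) returns a copy of the colormap resampled to six entries, so every input in [0, 1] falls into one of six flat color bands.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering
import matplotlib.pyplot as plt
import numpy as np

cmap6 = plt.get_cmap('Blues', 6)          # colormap resampled to 6 entries
colors = cmap6(np.arange(cmap6.N))        # one RGBA row per entry
print(cmap6.N, colors.shape)              # 6 (6, 4)

# 0.05 and 0.15 fall in the first sixth of [0, 1]; 0.2 is in the second
print(cmap6(0.05) == cmap6(0.15), cmap6(0.05) == cmap6(0.2))  # True False
```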
Example: Handwritten Digits
For an example of where this might be useful, let's look at an interesting visualization of some handwritten digits data.
This data is included in Scikit-Learn, and consists of nearly 2,000 8 × 8 thumbnails showing various handwritten digits.
For now, let's start by loading the digits data and visualizing several of the example images with plt.imshow():
In [12]: # load images of the digits 0 through 5 and visualize the first 64 of them
from sklearn.datasets import load_digits
digits = load_digits(n_class=6)
fig, ax = plt.subplots(8, 8, figsize=(6, 6))
for i, axi in enumerate(ax.flat):
    axi.imshow(digits.images[i], cmap='binary')
    axi.set(xticks=[], yticks=[])
Because each digit is defined by the brightness values of its 64 pixels, we can consider each digit to be a point lying in 64-dimensional
space: each dimension represents the brightness of one pixel. But visualizing relationships in such high-dimensional
spaces can be extremely difficult. One way to approach this is to use a dimensionality reduction technique such as
manifold learning to reduce the dimensionality of the data while maintaining the relationships of interest. Dimensionality
reduction is an example of unsupervised machine learning, and we will discuss it in more detail in What Is Machine
Learning?.
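The 64-dimensional view is just a reshape. In this sketch, each 8×8 image in digits.images, flattened row by row, equals the corresponding 64-element row of digits.data. digits.data is the array handed to Isomap below. load_digits reads a copy of the data that ships with scikit-learn, so nothing is downloaded.

```python
import numpy as np
from sklearn.datasets import load_digits

digits = load_digits(n_class=6)     # digits 0 through 5
n = len(digits.images)

# Flatten each 8x8 image row by row into one point in 64-dimensional space
flat = digits.images.reshape(n, -1)
print(digits.images.shape[1:], flat.shape[1])   # (8, 8) 64
```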
Deferring the discussion of these details, let's take a look at a two-dimensional manifold learning projection of this digits
data (see In-Depth: Manifold Learning for details):
In [13]: # project the digits into 2 dimensions using Isomap
from sklearn.manifold import Isomap
iso = Isomap(n_components=2)
projection = iso.fit_transform(digits.data)
We'll use our discrete colormap to view the results, setting the ticks and clim to improve the aesthetics of the
resulting colorbar:
In [14]: # plot the results
plt.scatter(projection[:, 0], projection[:, 1], lw=0.1,
            c=digits.target, cmap=plt.get_cmap('cubehelix', 6))
plt.colorbar(ticks=range(6), label='digit value')
plt.clim(-0.5, 5.5)
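Why clim(-0.5, 5.5) rather than clim(0, 5)? This short sketch uses matplotlib.colors.Normalize, the linear mapping whose limits clim sets. With these limits, digit k maps to (k + 0.5) / 6, the center of the k-th of the six color bands. Each tick from range(6) therefore lines up with the middle of its color on the colorbar.

```python
import numpy as np
from matplotlib.colors import Normalize

norm = Normalize(vmin=-0.5, vmax=5.5)   # same limits as plt.clim(-0.5, 5.5)
n_colors = 6                             # cubehelix resampled to 6 entries

positions = [float(norm(k)) for k in range(6)]   # where each digit lands in [0, 1]
bands = [int(p * n_colors) for p in positions]   # which color entry it picks
print(bands)  # [0, 1, 2, 3, 4, 5]
```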
The projection also gives us some interesting insights on the relationships within the dataset: for example, the ranges of 5
and 3 nearly overlap in this projection, indicating that some handwritten fives and threes are difficult to distinguish, and
therefore more likely to be confused by an automated classification algorithm. Other values, like 0 and 1, are more
distantly separated, and therefore much less likely to be confused. This observation agrees with our intuition, because 5
and 3 look much more similar than do 0 and 1.
We'll return to manifold learning and to digit classification in Chapter 5.
我们会在第五章再次看到流形学习和⼿写数字分类。
Multiple Subplots
多个⼦图表
Sometimes it is helpful to compare different views of data side by side. To this end, Matplotlib has the concept of
subplots: groups of smaller axes that can exist together within a single figure. These subplots might be insets, grids of
plots, or other more complicated layouts. In this section we'll explore four routines for creating subplots in Matplotlib.
在⼀些情况中,如果能将不同的数据图表并列展⽰,对于我们进⾏数据分析和⽐较会很有帮助。Matplotlib提供了⼦图表的概念来实现这⼀
点:单个图表中可以包括⼀组⼩的axes⽤来展⽰多个⼦图表。这些⼦图表可以是插图,⽹格状分布或其他更复杂的布局。在本节中我们会
介绍Matplotlib中⽤来构建⼦图表的四个函数。
In [1]: %matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('seaborn-v0_8-white')  # this style was named 'seaborn-white' before Matplotlib 3.6
import numpy as np
plt.axes: Subplots by Hand
plt.axes:手动构建子图表
The most basic method of creating an axes is to use the plt.axes function. As we've seen previously, by default this
creates a standard axes object that fills the entire figure. plt.axes also takes an optional argument that is a list of four
numbers in the figure coordinate system. These numbers represent [left, bottom, width, height] in the figure
coordinate system, which ranges from 0 at the bottom left of the figure to 1 at the top right of the figure.
构建axes作为子图表的最基础方法就是使用 plt.axes 函数。正如我们前面已经看到的,默认情况下,这个函数会创建一个填满整个图表区域
的标准axes对象。plt.axes 函数也可以接收一个可选的列表参数,用来指定axes在整个图表坐标系中的位置。列表中的四个数值分别为
[left, bottom, width, height](取值都是0-1,图表左下角为0,右上角为1),代表着子图表的左边、底部、宽度、高度在整个图表中所占的
比例。
For example, we might create an inset axes at the top-right corner of another axes by setting the x and y position to 0.65
(that is, starting at 65% of the width and 65% of the height of the figure) and the x and y extents to 0.2 (that is, the size of
the axes is 20% of the width and 20% of the height of the figure):
例如,我们可以在距离左边和底部65%的位置(即从图表宽度和高度的65%处开始),以插图的形式放置一个宽度和高度都是20%的子图表,
上述数值应该为 [0.65, 0.65, 0.2, 0.2]:
In [2]: ax1 = plt.axes()  # standard axes
ax2 = plt.axes([0.65, 0.65, 0.2, 0.2])  # inset axes
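The four numbers are fractions of the figure, which can be confirmed by asking the inset for its position. A minimal sketch, assuming the Agg backend and an arbitrary 6 × 4 inch figure: get_position() returns the rectangle in figure coordinates, and ax.bbox gives the same rectangle in pixels.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(6, 4))
ax1 = plt.axes()                         # standard axes
ax2 = plt.axes([0.65, 0.65, 0.2, 0.2])   # [left, bottom, width, height] as figure fractions

box = ax2.get_position()                 # rectangle in figure coordinates (0 to 1)
print(box.bounds)                        # (x0, y0, width, height)

# Converting to pixels: figure fraction times figure size in pixels
width_px, height_px = fig.get_size_inches() * fig.dpi
inset_left_px = box.x0 * width_px
inset_bottom_px = box.y0 * height_px
```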
The equivalent of this command within the object-oriented interface is fig.add_axes() . Let's use this to create two
vertically stacked axes:
与上述等价的⾯向对象接⼝的语法是 fig.add_axes() 。我们使⽤这个⽅法来创建两个垂直堆叠的⼦图表:
In [3]: fig = plt.figure()  # get the figure object
ax1 = fig.add_axes([0.1, 0.5, 0.8, 0.4],
                   xticklabels=[], ylim=(-1.2, 1.2))  # left 10%, bottom 50%, width 80%, height 40%
ax2 = fig.add_axes([0.1, 0.1, 0.8, 0.4],
                   ylim=(-1.2, 1.2))  # left 10%, bottom 10%, width 80%, height 40%

x = np.linspace(0, 10)
ax1.plot(np.sin(x))
ax2.plot(np.cos(x));
We now have two axes (the top with no tick labels) that are just touching: the bottom of the upper panel (at position 0.5)
matches the top of the lower panel (at position 0.1 + 0.4).
这样我们就有两个子图表(上面的子图表没有x轴刻度标签),这两个子图表正好相接:上面图表的底部位于整个图表高度50%的位置,
而下面图表的顶部也是整个图表的50%位置(0.1+0.4)。
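The "just touching" claim is arithmetic on the two rectangles: the lower panel's top edge is 0.1 + 0.4 = 0.5, exactly the upper panel's bottom edge. A quick check, assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
ax1 = fig.add_axes([0.1, 0.5, 0.8, 0.4], xticklabels=[], ylim=(-1.2, 1.2))
ax2 = fig.add_axes([0.1, 0.1, 0.8, 0.4], ylim=(-1.2, 1.2))

upper_bottom = ax1.get_position().y0   # bottom edge of the upper panel
lower_top = ax2.get_position().y1      # top edge of the lower panel: 0.1 + 0.4
print(upper_bottom, lower_top)
```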
plt.subplot: Simple Grids of Subplots
plt.subplot:简单网格的子图表
Aligned columns or rows of subplots are a common-enough need that Matplotlib has several convenience routines that
make them easy to create. The lowest level of these is plt.subplot() , which creates a single subplot within a grid.
As you can see, this command takes three integer arguments—the number of rows, the number of columns, and the
index of the plot to be created in this scheme, which runs from the upper left to the bottom right:
将⼦图表的⾏与列对⻬是⼀个很常⻅的需求,因此Matplotlib提供了⼀些简单的函数来实现它们。这些函数当中最底层的是
plt.subplot() ,它会在⽹格中创建⼀个⼦图表。函数接受三个整数参数,⽹格⾏数,⽹格列数以及该⽹格⼦图表的序号(从左上⻆向
右下⻆递增):
In [4]: for i in range(1, 7):
    plt.subplot(2, 3, i)
    plt.text(0.5, 0.5, str((2, 3, i)),
             fontsize=18, ha='center')
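The index counts across each row before moving down to the next one. A sketch that recovers each panel's (row, column) from its SubplotSpec, assuming Matplotlib 3.2 or later for the rowspan and colspan attributes and the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt

fig = plt.figure()
positions = {}
for i in range(1, 7):
    ax = plt.subplot(2, 3, i)            # 2 rows, 3 columns, panel number i
    spec = ax.get_subplotspec()
    positions[i] = (spec.rowspan.start, spec.colspan.start)

print(positions)  # panel i sits at divmod(i - 1, ncols)
```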
The command plt.subplots_adjust can be used to adjust the spacing between these plots. The following code
uses the equivalent object-oriented command, fig.add_subplot() :
plt.subplots_adjust 函数可以用来调整这些子图表之间的距离。下面的代码使用了与 plt.subplot() 等价的面向对象接口方法
fig.add_subplot():
In [5]: fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(1, 7):
    ax = fig.add_subplot(2, 3, i)
    ax.text(0.5, 0.5, str((2, 3, i)),
            fontsize=18, ha='center')
We've used the hspace and wspace arguments of plt.subplots_adjust , which specify the spacing along the
height and width of the figure, in units of the subplot size (in this case, the space is 40% of the subplot width and height).
上例中我们指定了 plt.subplots_adjust 函数的 hspace 和 wspace 参数,它们分别代表沿着高度和宽度方向子图表之间的距离,
单位是子图表的大小(在本例中,距离是子图表宽度和高度的40%)。
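Because hspace and wspace are expressed in units of the subplot size, the gap between two neighboring panels divided by a panel's width (or height) should come back as 0.4. A sketch, assuming equal-sized panels and the Agg backend (the figure size is arbitrary):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(6, 4))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
axes = [fig.add_subplot(2, 3, i) for i in range(1, 7)]

left, right = axes[0].get_position(), axes[1].get_position()   # neighbors in row 0
top, bottom = axes[0].get_position(), axes[3].get_position()   # neighbors in column 0

# Gap between neighbors, measured in units of the panel size
horizontal_gap_ratio = (right.x0 - left.x1) / left.width
vertical_gap_ratio = (top.y0 - bottom.y1) / top.height
print(horizontal_gap_ratio, vertical_gap_ratio)
```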
plt.subplots: The Whole Grid in One Go
plt.subplots:一句代码设置所有网格子图表
The approach just described can become quite tedious when creating a large grid of subplots, especially if you'd like to
hide the x- and y-axis labels on the inner plots. For this purpose, plt.subplots() is the easier tool to use (note the
s at the end of subplots ). Rather than creating a single subplot, this function creates a full grid of subplots in a
single line, returning them in a NumPy array. The arguments are the number of rows and number of columns, along with
optional keywords sharex and sharey , which allow you to specify the relationships between different axes.
当我们需要创建大量的子图表网格时,上面的方法会变得非常冗长乏味,特别是当我们需要隐藏内部图表的x轴和y轴标签的时候。因此,
在这种情况下 plt.subplots() 是更合适的工具(注意末尾有个s)。这个函数会用一行代码一次性创建所有的网格子图表,而不是单个子图表,
并将它们放在一个NumPy数组中返回。参数是行数和列数,还有两个可选的关键字参数 sharex 和 sharey,可以让你指定不同子图表之间的
关联。
Here we'll create a 2 × 3 grid of subplots, where all axes in the same row share their y-axis scale, and all axes in the
same column share their x-axis scale:
下面我们来创建一个 2×3 网格的子图表,其中每一行的子图表共享它们的y轴,而每一列的子图表共享它们的x轴:
In [6]: fig, ax = plt.subplots(2, 3, sharex='col', sharey='row')
Note that by specifying sharex and sharey , we've automatically removed inner labels on the grid to make the plot
cleaner. The resulting grid of axes instances is returned within a NumPy array, allowing for convenient specification of the
desired axes using standard array indexing notation:
注意上面我们设置了 sharex 和 sharey 之后,内部子图表的x轴和y轴标签就自动被去掉了,让图表更加简洁。返回值中 ax 是一个NumPy数组,
里面含有每一个子图表的实例,你可以使用标准的数组索引语法方便地获得它们:
In [7]: # axes are in a two-dimensional array, indexed by [row, col]
for i in range(2):
    for j in range(3):
        ax[i, j].text(0.5, 0.5, str((i, j)),
                      fontsize=18, ha='center')
fig
Out[7]:
In comparison to plt.subplot() , plt.subplots() is more consistent with Python's conventional 0-based
indexing.
并且相对于 plt.subplot(),plt.subplots() 更符合Python从0开始进行索引的习惯。
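Both claims about plt.subplots() (the NumPy array of axes and the row/column sharing) can be checked directly. A sketch assuming the Agg backend; get_shared_y_axes().joined(a, b) reports whether two axes share a y axis:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 3, sharex='col', sharey='row')

grid_shape = ax.shape  # rows by columns of Axes objects
same_row = ax[0, 0].get_shared_y_axes().joined(ax[0, 0], ax[0, 2])
same_column = ax[0, 1].get_shared_x_axes().joined(ax[0, 1], ax[1, 1])
different_rows = ax[0, 0].get_shared_y_axes().joined(ax[0, 0], ax[1, 0])

# Setting y limits on one panel propagates along its row only
ax[0, 0].set_ylim(-5, 5)
row0_ylims = [a.get_ylim() for a in ax[0]]
row1_ylim = ax[1, 0].get_ylim()
```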
plt.GridSpec: More Complicated Arrangements
plt.GridSpec:更复杂的排列
To go beyond a regular grid to subplots that span multiple rows and columns, plt.GridSpec() is the best tool. The
plt.GridSpec() object does not create a plot by itself; it is simply a convenient interface that is recognized by the
plt.subplot() command. For example, a gridspec for a grid of two rows and three columns with some specified
width and height space looks like this:
当你需要⼦图表在⽹格中占据多⾏或多列时, plt.GridSpec() 正是你所需要的。 plt.GridSpec() 对象并不⾃⼰创建图表;它只是
⼀个可以被传递给 plt.subplot() 的参数。例如,⼀个两⾏三列并带有指定的宽度⾼度间隔的gridspec可以如下创建:
In [8]: grid = plt.GridSpec(2, 3, wspace=0.4, hspace=0.3)
From this we can specify subplot locations and extents using the familiar Python slicing syntax:
使⽤这个对象我们可以指定⼦图表的位置和占据的⽹格,仅需要使⽤熟悉的Python切⽚语法即可:
In [9]: plt.subplot(grid[0, 0])
plt.subplot(grid[0, 1:])
plt.subplot(grid[1, :2])
plt.subplot(grid[1, 2]);
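Each slice of a GridSpec is a SubplotSpec recording which rows and columns a panel covers. A sketch (assuming Matplotlib 3.2+ for rowspan and colspan, and the Agg backend; the panel names are made up) that reads back the cells claimed by the four panels above:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt

fig = plt.figure()
grid = plt.GridSpec(2, 3, wspace=0.4, hspace=0.3)

spans = {}
for name, spec in [('top-left', grid[0, 0]), ('top-right', grid[0, 1:]),
                   ('bottom-left', grid[1, :2]), ('bottom-right', grid[1, 2])]:
    ax = fig.add_subplot(spec)
    ss = ax.get_subplotspec()
    spans[name] = (list(ss.rowspan), list(ss.colspan))  # (rows, columns) covered

print(spans)
```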
This type of flexible grid alignment has a wide range of uses. I most often use it when creating multi-axes histogram plots
like the ones shown here:
这种灵活的⽹格对⻬控制⽅式有着⼴泛的应⽤。作者经常在需要创建多个直⽅图的联合图表中使⽤这种⽅法,如下例:
In [10]: # Create some normally distributed data
mean = [0, 0]
cov = [[1, 1], [1, 2]]
x, y = np.random.multivariate_normal(mean, cov, 3000).T

# Set up the axes with GridSpec
fig = plt.figure(figsize=(6, 6))
grid = plt.GridSpec(4, 4, hspace=0.2, wspace=0.2)
main_ax = fig.add_subplot(grid[:-1, 1:])
y_hist = fig.add_subplot(grid[:-1, 0], xticklabels=[], sharey=main_ax)
x_hist = fig.add_subplot(grid[-1, 1:], yticklabels=[], sharex=main_ax)

# Scatter points on the main axes
main_ax.plot(x, y, 'ok', markersize=3, alpha=0.2)

# Histograms along the x and y directions on the attached axes
x_hist.hist(x, 40, histtype='stepfilled',
            orientation='vertical', color='gray')
x_hist.invert_yaxis()  # x histogram (bottom right): flip its y direction

y_hist.hist(y, 40, histtype='stepfilled',
            orientation='horizontal', color='gray')
y_hist.invert_xaxis()  # y histogram (top left): flip its x direction
This type of distribution plotted alongside its margins is common enough that it has its own plotting API in the Seaborn
package; see Visualization With Seaborn for more details.
这种沿着数据各⾃⽅向分布并绘制相应图表的需求是很通⽤的,因此在Seaborn包中它们有专⻔的API来实现;参⻅使⽤Seaborn进⾏可视
化来学习更多内容。
Text and Annotation
⽂本和标注
Creating a good visualization involves guiding the reader so that the figure tells a story. In some cases, this story can be
told in an entirely visual manner, without the need for added text, but in others, small textual cues and labels are
necessary. Perhaps the most basic types of annotations you will use are axes labels and titles, but the options go beyond
this. Let's take a look at some data and how we might visualize and annotate it to help convey interesting information.
We'll start by setting up the notebook for plotting and importing the functions we will use:
创建一个优秀的可视化图表的关键在于引导读者,让他们能理解图表所讲述的故事。在一些情况下,这个故事可以通过纯图像的方式表达,
不需要额外添加文字;但是在另外一些情况中,图表需要少量的文字提示和标签才能将故事讲好。也许最基本的标注类型就是坐标轴标签和
标题,但可用的选项远不止这些。让我们在本节中使用一些数据来创建可视化图表,并标注这些图表来表达其中有趣的信息。首先还是需要
将要用到的模块和包导入notebook:
In [1]: %matplotlib inline
import matplotlib.pyplot as plt
import matplotlib as mpl
plt.style.use('seaborn-v0_8-whitegrid')  # this style was named 'seaborn-whitegrid' before Matplotlib 3.6
import numpy as np
import pandas as pd
Example: Effect of Holidays on US Births
例⼦:节假⽇对美国出⽣率的影响
Let's return to some data we worked with earlier, in "Example: Birthrate Data", where we generated a plot of average
births over the course of the calendar year; as already mentioned, this data can be downloaded at
https://raw.githubusercontent.com/jakevdp/data-CDCbirths/master/births.csv.
本例中的数据是前⾯章节我们已经⽤过的(参⻅"例⼦:出⽣率"),当时我们对年内的平均出⽣数据创建了⼀个图表;就像前⾯已经提到
的,这个数据可以在 https://raw.githubusercontent.com/jakevdp/data-CDCbirths/master/births.csv 下载。
We'll start with the same cleaning procedure we used there, and plot the results:
我们先按照前⾯的⽅式进⾏同样的数据清洗程序,然后以图表展⽰这个结果:
In [2]: births = pd.read_csv('data/births.csv')

quartiles = np.percentile(births['births'], [25, 50, 75])
mu, sig = quartiles[1], 0.74 * (quartiles[2] - quartiles[0])
births = births.query('(births > @mu - 5 * @sig) & (births < @mu + 5 * @sig)')

births['day'] = births['day'].astype(int)

births.index = pd.to_datetime(10000 * births.year +
                              100 * births.month +
                              births.day, format='%Y%m%d')
births_by_date = births.pivot_table('births',
                                    [births.index.month, births.index.day])
# pd.datetime was removed in pandas 2.0; pd.Timestamp builds the same dates
births_by_date.index = [pd.Timestamp(2012, month, day)
                        for (month, day) in births_by_date.index]
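The factor 0.74 in the cleaning step turns the interquartile range into a robust estimate of the standard deviation: for a Gaussian the interquartile range is about 1.349σ, and 1 / 1.349 ≈ 0.74. The sketch below uses synthetic numbers made up for illustration (not the CDC data) to show the estimate tracking the true σ even with a few wild values present, and the 5σ cut discarding those values:

```python
import numpy as np

rng = np.random.default_rng(0)
true_mu, true_sigma = 4000, 300                      # hypothetical daily-birth scale
clean = rng.normal(true_mu, true_sigma, size=10_000)
corrupted = np.concatenate([clean, [99_999, 0, 50_000]])  # a few bad entries

# Same robust statistics as the cleaning step above
quartiles = np.percentile(corrupted, [25, 50, 75])
mu, sig = quartiles[1], 0.74 * (quartiles[2] - quartiles[0])

# Keep only values within 5 robust sigmas of the median
mask = (corrupted > mu - 5 * sig) & (corrupted < mu + 5 * sig)
kept = corrupted[mask]
print(mu, sig, len(kept))
```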
In [3]: fig, ax = plt.subplots(figsize=(12, 4))
births_by_date.plot(ax=ax);
When we're communicating data like this, it is often useful to annotate certain features of the plot to draw the reader's
attention. This can be done manually with the plt.text / ax.text command, which will place text at a particular x/y
value:
当我们绘制了这样的图表来表达数据时,如果我们能对⼀些图表的特性作出标注来吸引读者的注意⼒通常是⾮常有帮助的。这可以通过调
⽤ plt.text 或 ax.text 函数来实现,它们可以在某个特定的x,y轴位置输出⼀段⽂字:
In [4]: fig, ax = plt.subplots(figsize=(12, 4))
births_by_date.plot(ax=ax)

# Add labels at specific points on the curve
style = dict(size=10, color='gray')

ax.text('2012-1-1', 3950, "New Year's Day", **style)
ax.text('2012-7-4', 4250, "Independence Day", ha='center', **style)
ax.text('2012-9-4', 4850, "Labor Day", ha='center', **style)
ax.text('2012-10-31', 4600, "Halloween", ha='right', **style)
ax.text('2012-11-25', 4450, "Thanksgiving", ha='center', **style)
ax.text('2012-12-25', 3850, "Christmas ", ha='right', **style)

# Set the title and the y-axis label
ax.set(title='USA births by day of year (1969-1988)',
       ylabel='average daily births')

# Center the month labels on the x axis:
# unlabeled major ticks at month boundaries, labeled minor ticks on the 15th
ax.xaxis.set_major_locator(mpl.dates.MonthLocator())
ax.xaxis.set_minor_locator(mpl.dates.MonthLocator(bymonthday=15))
ax.xaxis.set_major_formatter(plt.NullFormatter())
ax.xaxis.set_minor_formatter(mpl.dates.DateFormatter('%h'));
The ax.text method takes an x position, a y position, a string, and then optional keywords specifying the color, size,
style, alignment, and other properties of the text. Here we used ha='right' and ha='center' , where ha is short
for horizontal alignment. See the docstring of plt.text() and of mpl.text.Text() for more information on
available options.
ax.text 方法接收x位置、y位置、一个字符串和额外可选的关键字参数,这些关键字参数可以用来设置文本的颜色、大小、样式、对齐方式
等属性。上面我们使用了 ha='right' 和 ha='center',这里的 ha 是horizontal alignment(水平对齐)的缩写。要查阅更多的可用参数,
请查看 plt.text() 和 mpl.text.Text() 的文档字符串内容。
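The effect of ha can be measured: once the figure is drawn, each Text object reports its bounding box in pixels. In this sketch (Agg backend assumed; the label and coordinates are arbitrary), ha='left' puts the left edge of the text at the anchor, ha='right' the right edge, and ha='center' the midpoint:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)

# Same anchor point, three horizontal alignments
texts = {ha: ax.text(5, 5, "Labor Day", ha=ha) for ha in ('left', 'center', 'right')}
fig.canvas.draw()                              # text sizes are only known after drawing

anchor_x = ax.transData.transform((5, 5))[0]   # anchor position in pixels
extents = {ha: t.get_window_extent() for ha, t in texts.items()}
```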
Transforms and Text Position
转换和⽂本位置
In the previous example, we have anchored our text annotations to data locations. Sometimes it's preferable to anchor
the text to a position on the axes or figure, independent of the data. In Matplotlib, this is done by modifying the transform.
在刚才的例⼦中,我们将⽂字标注根据数据位置进⾏了定位。有些时候我们需要将⽂字标注独⽴于数据位置⽽根据图表位置进⾏定位。
Matplotlib通过转换完成这项⼯作。
Any graphics display framework needs some scheme for translating between coordinate systems. For example, a data
point at (x, y) = (1, 1) needs to somehow be represented at a certain location on the figure, which in turn needs to be
represented in pixels on the screen. Mathematically, such coordinate transformations are relatively straightforward, and
Matplotlib has a well-developed set of tools that it uses internally to perform them (these tools can be explored in the
matplotlib.transforms submodule).
任何的图形显示框架都需要在坐标系统之间进行转换的机制。例如,一个位于 (x, y) = (1, 1) 的数据点需要以某种方式被转换为图表中的
某个位置,进而转换为屏幕上显示的像素。这样的坐标转换在数学上相对来说比较直接,而且Matplotlib内部使用了一套完善的工具来实现
这些转换(这些工具可以在 matplotlib.transforms 子模块中找到)。
The average user rarely needs to worry about the details of these transforms, but it is helpful knowledge to have when
considering the placement of text on a figure. There are three pre-defined transforms that can be useful in this situation:
ax.transData : Transform associated with data coordinates
ax.transAxes : Transform associated with the axes (in units of axes dimensions)
fig.transFigure : Transform associated with the figure (in units of figure dimensions)
⼀般来说,⽤⼾很少需要关注这些转换的细节,但是当考虑将⽂本在图表上展⽰时,这些知识却⽐较有⽤。在这种情况中,下⾯三种定义
好的转换是⽐较有⽤的:
ax.transData :与数据坐标相关的转换
ax.transAxes:与axes相关的转换(单位是axes的宽和高)
fig.transFigure:与figure相关的转换(单位是figure的宽和高)
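Each of these transforms maps its own coordinates to display pixels, and inverted() goes the other way. A sketch with an explicitly sized figure (4 × 3 inches at 100 dpi, chosen only so the pixel values are easy to check; Agg backend assumed):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(figsize=(4, 3), dpi=100)
ax.axis([0, 10, 0, 10])

# Figure fraction -> pixels: (1, 1) is the top-right corner of the figure
fig_corner = fig.transFigure.transform((1, 1))   # (400, 300) for 4x3 inches at 100 dpi

# Axes fraction (0, 0) and data (0, 0) are both the axes' bottom-left corner here
axes_origin = ax.transAxes.transform((0, 0))
data_origin = ax.transData.transform((0, 0))

# Transforms can be inverted: pixels -> data
point_px = ax.transData.transform((1, 5))
round_trip = ax.transData.inverted().transform(point_px)
```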
Here let's look at an example of drawing text at various locations using these transforms:
下⾯我们来看看使⽤这些转换将⽂字写在图表中不同位置的例⼦:
In [5]: fig, ax = plt.subplots(facecolor='lightgray')
ax.axis([0, 10, 0, 10])
# transform=ax.transData is the default; it is written out here for explicit comparison
ax.text(1, 5, ". Data: (1, 5)", transform=ax.transData)
ax.text(0.5, 0.1, ". Axes: (0.5, 0.1)", transform=ax.transAxes)
ax.text(0.2, 0.2, ". Figure: (0.2, 0.2)", transform=fig.transFigure);
Note that by default, the text is left-aligned on its baseline, so it sits above and to the right of the specified
coordinates: here the "." at the beginning of each string approximately marks the given coordinate location.
注意默认情况下,⽂字是在指定坐标位置靠左对⻬的:这⾥每个字符串开始的"."的位置就是每种转换的坐标位置。
The transData coordinates give the usual data coordinates associated with the x- and y-axis labels. The
transAxes coordinates give the location from the bottom-left corner of the axes (here the white box), as a fraction of
the axes size. The transFigure coordinates are similar, but specify the position from the bottom-left of the figure
(here the gray box), as a fraction of the figure size.
transData 坐标给定的是通常使用的、与x轴和y轴标签对应的数据坐标。transAxes 坐标给定的是从axes左下角(这里的白色区域)开始
算起的位置,使用的是axes宽度和高度的占比。transFigure 坐标类似,给定的是从figure左下角(这里的灰色区域)开始算起的位置,
使用的是figure宽度和高度的占比。
Notice now that if we change the axes limits, it is only the transData coordinates that will be affected, while the others
remain stationary:
因此,如果我们改变了坐标轴的范围,只有 transData 坐标会受到影响,其他两个还是保持在相同位置:
In [6]: ax.set_xlim(0, 2)
ax.set_ylim(-6, 6)
fig
Out[6]:
This behavior can be seen more clearly by changing the axes limits interactively: if you are executing this code in a
notebook, you can make that happen by changing %matplotlib inline to %matplotlib notebook and using
each plot's menu to interact with the plot.
这个行为在交互式地改变坐标轴范围时能看得更加清楚:如果你在notebook中执行这段代码,可以将 %matplotlib inline 改为
%matplotlib notebook,然后使用每个图表的菜单来与图表交互。
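The same behavior can be verified without an interactive backend by recording where each label's anchor lands in pixels before and after changing the limits. A sketch assuming the Agg backend; display_position is a helper written for this example:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.axis([0, 10, 0, 10])
data_text = ax.text(1, 5, ". Data: (1, 5)", transform=ax.transData)
axes_text = ax.text(0.5, 0.1, ". Axes: (0.5, 0.1)", transform=ax.transAxes)

def display_position(text):
    """Pixel location of a Text object's anchor point (hypothetical helper)."""
    return text.get_transform().transform(text.get_position())

before = {'data': display_position(data_text), 'axes': display_position(axes_text)}
ax.set_xlim(0, 2)
ax.set_ylim(-6, 6)
after = {'data': display_position(data_text), 'axes': display_position(axes_text)}
```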
Arrows and Annotation
箭头和标注
Along with tick marks and text, another useful annotation mark is the simple arrow.
除了刻度标签和⽂字标签,另⼀种常⽤的标注是箭头。
Drawing arrows in Matplotlib is often much harder than you'd bargain for. While there is a plt.arrow() function
available, I wouldn't suggest using it: the arrows it creates are SVG objects that will be subject to the varying aspect ratio
of your plots, and the result is rarely what the user intended. Instead, I'd suggest using the plt.annotate() function.
This function creates some text and an arrow, and the arrows can be very flexibly specified.
在Matplotlib中绘制箭头通常⽐你想象的难得多。虽然有 plt.arrow() 函数,作者不建议使⽤它:这个函数绘制的箭头是⼀个SVG对
象,因此在图表使⽤不同的⽐例的情况会产⽣问题,结果通常不能让⽤⼾满意。因此,作者建议使⽤ plt.annotate() 函数。这个函数
会绘制⼀些⽂字以及⼀个箭头,并且箭头可以⾮常灵活的进⾏配置。
Here we'll use annotate with several of its options:
下⾯我们提供⼀些参数来使⽤ annotate 函数:
In [7]: %matplotlib inline
fig, ax = plt.subplots()
x = np.linspace(0, 20, 1000)
ax.plot(x, np.cos(x))
ax.axis('equal')
ax.annotate('local maximum', xy=(6.28, 1), xytext=(10, 4),
arrowprops=dict(facecolor='black', shrink=0.05))
ax.annotate('local minimum', xy=(5 * np.pi, -1), xytext=(2, -6),
arrowprops=dict(arrowstyle="->",
connectionstyle="angle3,angleA=0,angleB=-90"));
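Under the hood, annotate returns an Annotation: a Text object placed at xytext, plus an arrow (a FancyArrowPatch) pointing at xy only when arrowprops is given. A small sketch, assuming the Agg backend; the second label is made up for contrast:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import FancyArrowPatch

fig, ax = plt.subplots()
x = np.linspace(0, 20, 1000)
ax.plot(x, np.cos(x))

with_arrow = ax.annotate('local maximum', xy=(6.28, 1), xytext=(10, 4),
                         arrowprops=dict(facecolor='black', shrink=0.05))
text_only = ax.annotate('no arrow here', xy=(2, 0))   # no arrowprops given

target = with_arrow.xy             # the point being annotated
text_position = with_arrow.xyann   # where the text itself sits
```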
The arrow style is controlled through the arrowprops dictionary, which has numerous options available. These options
are fairly well-documented in Matplotlib's online documentation, so rather than repeating them here it is probably more
useful to quickly show some of the possibilities. Let's demonstrate several of the possible options using the birthrate plot
from before:
箭头的样式是通过 arrowprops 字典进行控制的,里面有很多可用的参数。这些参数在Matplotlib的在线文档中已经有了很详细的说明,
因此这里不再重复介绍,而是快速展示其中的一些效果。我们在前面的出生率图表上演示几种可用的参数:
In [8]: fig, ax = plt.subplots(figsize=(12, 4))
births_by_date.plot(ax=ax)
# Add labels to the plot
ax.annotate("New Year's Day", xy=('2012-1-1', 4100), xycoords='data',
xytext=(50, -30), textcoords='offset points',
arrowprops=dict(arrowstyle="->",
connectionstyle="arc3,rad=-0.2"))
ax.annotate("Independence Day", xy=('2012-7-4', 4250), xycoords='data',
bbox=dict(boxstyle="round", fc="none", ec="gray"),
xytext=(10, -40), textcoords='offset points', ha='center',
arrowprops=dict(arrowstyle="->"))
ax.annotate('Labor Day', xy=('2012-9-4', 4850), xycoords='data', ha='center',
xytext=(0, -20), textcoords='offset points')
ax.annotate('', xy=('2012-9-1', 4850), xytext=('2012-9-7', 4850),
xycoords='data', textcoords='data',
arrowprops={'arrowstyle': '|-|,widthA=0.2,widthB=0.2', })
ax.annotate('Halloween', xy=('2012-10-31', 4600), xycoords='data',
xytext=(-80, -40), textcoords='offset points',
arrowprops=dict(arrowstyle="fancy",
fc="0.6", ec="none",
connectionstyle="angle3,angleA=0,angleB=-90"))
ax.annotate('Thanksgiving', xy=('2012-11-25', 4500), xycoords='data',
xytext=(-120, -60), textcoords='offset points',
bbox=dict(boxstyle="round4,pad=.5", fc="0.9"),
arrowprops=dict(arrowstyle="->",
connectionstyle="angle,angleA=0,angleB=80,rad=20"))
ax.annotate('Christmas', xy=('2012-12-25', 3850), xycoords='data',
xytext=(-30, 0), textcoords='offset points',
size=13, ha='right', va="center",
bbox=dict(boxstyle="round", alpha=0.1),
arrowprops=dict(arrowstyle="wedge,tail_width=0.5", alpha=0.1));
# Set the title and axis labels
ax.set(title='USA births by day of year (1969-1988)',
       ylabel='average daily births')

# Center the month labels on the x axis
ax.xaxis.set_major_locator(mpl.dates.MonthLocator())
ax.xaxis.set_minor_locator(mpl.dates.MonthLocator(bymonthday=15))
ax.xaxis.set_major_formatter(plt.NullFormatter())
ax.xaxis.set_minor_formatter(mpl.dates.DateFormatter('%h'));
ax.set_ylim(3600, 5400);
You'll notice that the specifications of the arrows and text boxes are very detailed: this gives you the power to create
nearly any arrow style you wish. Unfortunately, it also means that these sorts of features often must be manually
tweaked, a process that can be very time consuming when producing publication-quality graphics! Finally, I'll note that the
preceding mix of styles is by no means best practice for presenting data, but rather included as a demonstration of some
of the available options.
上图中箭头和⽂字框都⾮常详尽了:可以看出你⼏乎可以使⽤ plt.annotate 创建任何你想要的箭头样式。不幸的是,这意味着这种特
性都需要⼿⼯进⾏调整,因此如果需要获得印刷质量的图像,这将是⼀个⾮常耗费时间的⼯作。最后,必须指出,上述这种多种样式混合
的⽅式来展现数据肯定不是最佳实践,这⾥只是为了尽可能多的介绍可⽤的参数。
More discussion and examples of available arrow and annotation styles can be found in the Matplotlib gallery, in
particular the Annotation Demo.
更多关于Matplotlib的箭头和标注样式的讨论和例⼦可以访问Matplotlib gallery,特别是标注演⽰。
Customizing Ticks
⾃定义刻度
Matplotlib's default tick locators and formatters are designed to be generally sufficient in many common situations, but are
in no way optimal for every plot. This section will give several examples of adjusting the tick locations and formatting for
the particular plot type you're interested in.
Matplotlib默认的刻度定位器(locator)和格式器(formatter)被设计成能满足许多常见场景的需求,但是不会是所有图表的最佳选择。
本节会介绍一些例子,说明如何针对你感兴趣的特定图表类型调整刻度的位置和格式。
Before we go into examples, it will be best for us to understand further the object hierarchy of Matplotlib plots. Matplotlib
aims to have a Python object representing everything that appears on the plot: for example, recall that the figure is
the bounding box within which plot elements appear. Each Matplotlib object can also act as a container of sub-objects: for
example, each figure can contain one or more axes objects, each of which in turn contains other objects
representing plot contents.
在介绍例⼦之前,我们应该加深对Matplotlib图表的对象层次的理解。Matplotlib的设计⽬标是展⽰在图表中的所有内容都会表达成为
Python的对象:例如,回忆前⾯我们介绍过 figure 指的是⽤来展⽰图表所有内容的⽅框。每个Matplotlib对象也被设计为其⼦对象的⼀
个容器:例如 figure 对象中可以包含⼀个或多个 axes 对象,每个 axes 对象都依次包含着其他⽤来展⽰图表的内容对象。
The tick marks are no exception. Each axes has attributes xaxis and yaxis , which in turn have attributes that
contain all the properties of the lines, ticks, and labels that make up the axes.
刻度也不例外。每个 axes 对象都有着属性 xaxis 和 yaxis ,表⽰x和y轴,其中包含着所有的属性⽤来指代轴的线、刻度和标签。
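The hierarchy can be walked in code: the figure lists its axes, each axes points back to its figure, and each axes owns an XAxis and a YAxis that hold the ticks. A sketch assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

contains_axes = ax in fig.axes            # the figure holds its axes
knows_parent = ax.figure is fig           # and each axes points back to its figure
xaxis_type = type(ax.xaxis).__name__      # class of the x-axis container
yaxis_type = type(ax.yaxis).__name__      # class of the y-axis container

fig.canvas.draw()                         # ticks are filled in during drawing
n_major_x = len(ax.xaxis.get_major_ticks())
```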
Major and Minor Ticks
主要的和次要的刻度
Within each axis, there is the concept of a major tick mark, and a minor tick mark. As the names would imply, major ticks
are usually bigger or more pronounced, while minor ticks are usually smaller. By default, Matplotlib rarely makes use of
minor ticks, but one place you can see them is within logarithmic plots:
在每个坐标轴上,都有主要刻度和次要刻度的概念。正如名字所暗示的,主要刻度通常更大、更显眼,而次要刻度通常更小。默认情况下
Matplotlib很少使用次要刻度,但是在对数图表中我们可以看到它们:
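Translator's note (译者注): Since Matplotlib 2.0, minor ticks are no longer shown by default when an axis spans too wide a
range, so the code below has been modified to add the xlim and ylim arguments.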
In [1]: import matplotlib.pyplot as plt
plt.style.use('classic')
%matplotlib inline
import numpy as np
In [2]: ax = plt.axes(xscale='log', yscale='log', xlim=[10e-5, 10e5], ylim=[10e-5, 10e5])
ax.grid();
We see here that each major tick shows a large tickmark and a label, while each minor tick shows a smaller tickmark with
no label.
我们看到每个主要刻度都显示了一个较大的刻度标志和标签,而每个次要刻度只显示了一个较小的刻度标志,没有标签。
These tick properties (locations and labels) can be customized by setting the formatter and locator objects of each
axis. Let's examine these for the x axis of the plot just shown:
这些刻度属性,位置和标签,都可以使⽤每个轴的 formatter 和 locator 对象进⾏个性化设置。下⾯我们来查看⼀下x轴的相应对象:
In [3]: print(ax.xaxis.get_major_locator())
print(ax.xaxis.get_minor_locator())
<matplotlib.ticker.LogLocator object at 0x7f7f7419edd8>
<matplotlib.ticker.LogLocator object at 0x7f7f58fc8f98>
In [4]: print(ax.xaxis.get_major_formatter())
print(ax.xaxis.get_minor_formatter())
<matplotlib.ticker.LogFormatterSciNotation object at 0x7f7f740d1e80>
<matplotlib.ticker.LogFormatterSciNotation object at 0x7f7f58fc88d0>
We see that both major and minor tick labels have their locations specified by a LogLocator (which makes sense for a
logarithmic plot). Minor ticks, though, have their labels formatted by a NullFormatter : this says that no labels will be
shown.
我们看到主要和次要刻度的位置都是使⽤ LogLocator 来设置的(对于对数图表来说那是理所当然的)。然⽽次要刻度的标签的格式是
NullFormatter :这表⽰次要刻度不会显⽰标签。
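Translator's note (译者注): This has changed in newer versions of Matplotlib: as the output above shows, both formatters are
now LogFormatterSciNotation, which decides whether to show labels based on the actual plot.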
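For comparison, a linear axis gets different defaults. The sketch below (assuming a recent Matplotlib with its default style and the Agg backend; as the translator's note shows, the exact classes can change between versions) inspects a linear and a logarithmic x axis side by side:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
from matplotlib import ticker

fig, (lin_ax, log_ax) = plt.subplots(1, 2)
log_ax.set_xscale('log')

lin_locator = lin_ax.xaxis.get_major_locator()
lin_formatter = lin_ax.xaxis.get_major_formatter()
lin_minor_formatter = lin_ax.xaxis.get_minor_formatter()
log_locator = log_ax.xaxis.get_major_locator()

print(type(lin_locator).__name__, type(lin_formatter).__name__,
      type(lin_minor_formatter).__name__, type(log_locator).__name__)
```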
We'll now show a few examples of setting these locators and formatters for various plots.
下⾯我们就可以开始介绍⼀些设置这些locator和formatter的例⼦了。
Hiding Ticks or Labels
隐藏刻度和标签
Perhaps the most common tick/label formatting operation is the act of hiding ticks or labels. This can be done using
plt.NullLocator() and plt.NullFormatter() , as shown here:
也许最常⻅的刻度/标签格式设置的操作是隐藏刻度或标签。这可以通过使⽤ plt.NullLocator() 和 plt.NullFormatter() 来设
置,如下例:
In [5]: ax = plt.axes()
ax.plot(np.random.rand(50))
ax.yaxis.set_major_locator(plt.NullLocator())
ax.xaxis.set_major_formatter(plt.NullFormatter())
Notice that we've removed the labels (but kept the ticks/gridlines) from the x axis, and removed the ticks (and thus the
labels as well) from the y axis. Having no ticks at all can be useful in many situations—for example, when you want to
show a grid of images. For instance, consider the following figure, which includes images of different faces, an example
often used in supervised machine learning problems (see, for example, In-Depth: Support Vector Machines):
注意上图中我们去除了x轴的标签(但是保留了刻度和网格线),并去除了y轴的刻度(因此标签也一并被去除了)。图表中完全没有刻度
在很多情况下很有用,例如当你希望展示一个图像网格的时候。比方说,考虑下面的图表,它包含了不同的人脸图像,这是有监督机器学习
问题中很常用的一个例子(参见深入:支持向量机):
In [6]: fig, ax = plt.subplots(5, 5, figsize=(5, 5))
fig.subplots_adjust(hspace=0, wspace=0)

# Get some face data from scikit-learn
from sklearn.datasets import fetch_olivetti_faces
faces = fetch_olivetti_faces().images

for i in range(5):
    for j in range(5):
        ax[i, j].xaxis.set_major_locator(plt.NullLocator())
        ax[i, j].yaxis.set_major_locator(plt.NullLocator())
        ax[i, j].imshow(faces[10 * i + j], cmap="bone")
Notice that each image has its own axes, and we've set the locators to null because the tick values (pixel number in this
case) do not convey relevant information for this particular visualization.
注意上图中每张图像都有它⾃⼰的axes,我们将每⼀个axes的locator都设置为null因为这些刻度值(像素值)在这⾥并没有任何实际意
义。
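Null locators and formatters are ordinary callables, so their behavior can be checked without looking at a plot: a NullLocator yields no tick positions and a NullFormatter turns every value into an empty string. A sketch assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

no_ticks = plt.NullLocator().tick_values(0, 1)   # no tick positions at all
empty_label = plt.NullFormatter()(0.5)          # any value formats to ''

fig, ax = plt.subplots()
ax.plot(np.random.rand(50))
ax.yaxis.set_major_locator(plt.NullLocator())
ax.xaxis.set_major_formatter(plt.NullFormatter())
fig.canvas.draw()

y_tick_count = len(ax.yaxis.get_major_locator()())       # the y axis has no ticks
x_labels = [t.get_text() for t in ax.get_xticklabels()]  # x ticks kept, labels blank
```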
Reducing or Increasing the Number of Ticks
减少或增加刻度的数量
One common problem with the default settings is that smaller subplots can end up with crowded labels. We can see this
in the plot grid shown here:
默认设置的⼀个常⻅问题是当⼦图表较⼩时,刻度标签可能会粘在⼀起。我们可以从下⾯例⼦看到:
In [7]: fig, ax = plt.subplots(4, 4, sharex=True, sharey=True)
Particularly for the x ticks, the numbers nearly overlap and make them quite difficult to decipher. We can fix this with the
plt.MaxNLocator() , which allows us to specify the maximum number of ticks that will be displayed. Given this
maximum number, Matplotlib will use internal logic to choose the particular tick locations:
特别是x轴,标签的数字几乎重叠在一起,让这些标签难以辨认。我们可以通过 plt.MaxNLocator() 来修正这点,用它可以设置展示刻度的
最大数量。给定这个最大数量后,Matplotlib会使用内部逻辑来选择具体的刻度位置:
In [8]: # For every axis, set the x and y major locator
for axi in ax.flat:
    axi.xaxis.set_major_locator(plt.MaxNLocator(3))
    axi.yaxis.set_major_locator(plt.MaxNLocator(3))
fig
Out[8]:
This makes things much cleaner. If you want even more control over the locations of regularly-spaced ticks, you might
also use plt.MultipleLocator , which we'll discuss in the following section.
上图就清晰多了。如果你希望对于刻度位置进⾏更加精细的控制,你可以使⽤ plt.MultipleLocator ,我们会接下来讨论这个对象。
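Strictly speaking, the argument of MaxNLocator is nbins, the maximum number of intervals, so at most nbins + 1 ticks fall in the visible range. A sketch checking this on the interval [0, 1] (the nbins values are arbitrary; Agg backend assumed):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

tol = 1e-9
visible_counts = {}
for nbins in (1, 2, 3, 5):
    ticks = plt.MaxNLocator(nbins).tick_values(0, 1)
    visible = ticks[(ticks >= -tol) & (ticks <= 1 + tol)]   # ticks inside [0, 1]
    visible_counts[nbins] = len(visible)

print(visible_counts)
```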
Fancy Tick Formats
复杂的刻度格式
Matplotlib's default tick formatting can leave a lot to be desired: it works well as a broad default, but sometimes you'd like
to do something more. Consider this plot of a sine and a cosine:
Matplotlib的默认刻度格式在很多常见情况下都工作良好,但是在特殊情况下你会希望能够进行更多的个性化设置。考虑下面的正弦和
余弦图表:
In [9]: # Plot a sine and cosine curve
fig, ax = plt.subplots()
x = np.linspace(0, 3 * np.pi, 1000)
ax.plot(x, np.sin(x), lw=3, label='Sine')
ax.plot(x, np.cos(x), lw=3, label='Cosine')

# Set up grid, legend, and limits
ax.grid(True)
ax.legend(frameon=False)
ax.axis('equal')
ax.set_xlim(0, 3 * np.pi);
There are a couple of changes we might like to make. First, it's more natural for this data to space the ticks and grid lines in
multiples of π. We can do this by setting a MultipleLocator, which locates ticks at a multiple of the number you
provide. For good measure, we'll put major ticks at multiples of π/2 and minor ticks at multiples of π/4:
这里有几个我们希望进行的改变。首先,对于这组数据来说,如果刻度和网格线的间距是 π 的倍数会显得更加自然。我们可以通过
MultipleLocator 来实现,它会把刻度放在你所提供数值的倍数位置上。为了更直观,我们把主要刻度设置在 π/2 的倍数位置,次要刻度
设置在 π/4 的倍数位置:
In [10]: ax.xaxis.set_major_locator(plt.MultipleLocator(np.pi / 2))
ax.xaxis.set_minor_locator(plt.MultipleLocator(np.pi / 4))
fig
Out[10]:
But now these tick labels look a little bit silly: we can see that they are multiples of π , but the decimal representation does
not immediately convey this. To fix this, we can change the tick formatter. There's no built-in formatter for what we want to
do, so we'll instead use plt.FuncFormatter , which accepts a user-defined function giving fine-grained control over
the tick outputs:
但是现在这些刻度标签看起来有点傻:我们可以看出刻度确实是 π 的倍数,但是小数形式的表示并不能让人一眼看出这一点。要修正这些
标签,我们需要修改刻度的formatter。没有内建的formatter能满足我们的需求,因此我们使用 plt.FuncFormatter,这个对象能够接受一个
用户自定义的函数,从而对刻度标签的输出进行精细的控制:
In [11]: def format_func(value, tick_number):
    # N is the number of multiples of pi/2
    N = int(np.round(2 * value / np.pi))
    if N == 0:
        return "0"                             # zero
    elif N == 1:
        return r"$\frac{\pi}{2}$"              # pi/2
    elif N == 2:
        return r"$\pi$"                        # pi
    elif N % 2 > 0:
        return r"$\frac{{%d}\pi}{2}$" % N      # N*pi/2 for odd N
    else:
        return r"${0}\pi$".format(N // 2)      # (N/2)*pi for even N

ax.xaxis.set_major_formatter(plt.FuncFormatter(format_func))
fig
Out[11]:
This is much better! Notice that we've made use of Matplotlib's LaTeX support, specified by enclosing the string within
dollar signs. This is very convenient for display of mathematical symbols and formulae: in this case, "$\pi$" is
rendered as the Greek character π .
上图看起来好多了。注意到我们使用了Matplotlib的LaTeX支持,即用美元符号将字符串括起来。这是展示数学符号和公式的简便方法:
在这个例子中 "$\pi$" 被渲染成希腊字母 π。
The plt.FuncFormatter() offers extremely fine-grained control over the appearance of your plot ticks, and comes
in very handy when preparing plots for presentation or publication.
plt.FuncFormatter() 提供了对于图表刻度外观极其精细的控制,在准备用于演示或出版的图表时非常方便。
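The locator and formatter can also be exercised together without drawing anything. This sketch repeats format_func from above so it is self-contained, asks a MultipleLocator for multiples of π/2 on [0, 3π], and formats each one (Agg backend assumed):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

def format_func(value, tick_number):
    # N is the number of multiples of pi/2
    N = int(np.round(2 * value / np.pi))
    if N == 0:
        return "0"
    elif N == 1:
        return r"$\frac{\pi}{2}$"
    elif N == 2:
        return r"$\pi$"
    elif N % 2 > 0:
        return r"$\frac{{%d}\pi}{2}$" % N
    else:
        return r"${0}\pi$".format(N // 2)

locator = plt.MultipleLocator(np.pi / 2)
formatter = plt.FuncFormatter(format_func)

# MultipleLocator may return an extra tick beyond each end of the range,
# so keep only the ticks that fall inside [0, 3*pi]
candidates = locator.tick_values(0, 3 * np.pi)
tol = 1e-9
ticks = [t for t in candidates if -tol <= t <= 3 * np.pi + tol]
labels = [formatter(t, i) for i, t in enumerate(ticks)]
print(labels)
```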
Summary of Formatters and Locators
Formatter和Locator总结
We've mentioned a couple of the available formatters and locators. We'll conclude this section by briefly listing all the
built-in locator and formatter options. For more information on any of these, refer to the docstrings or to the Matplotlib
online documentation. Each of the following is available in the plt namespace:
我们已经介绍了一些formatter和locator。最后我们简要列出所有内建的locator和formatter选项,对本节做一个总结。要获得更多相关内容,
请参阅它们的文档字符串或Matplotlib的在线文档。下表中列出的对象在 plt 命名空间中都是可用的:
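Locator 对象          描述 (Description)
NullLocator           无刻度 (no ticks)
FixedLocator          固定刻度位置 (tick locations are fixed)
IndexLocator          序号图表刻度,例如 x = range(len(y)) (locator for index plots)
LinearLocator         从最小到最大值的均匀分割刻度 (evenly spaced ticks from min to max)
LogLocator            从最小到最大值的对数分割刻度 (logarithmically spaced ticks from min to max)
MultipleLocator       某个基数的倍数刻度 (ticks at multiples of a base)
MaxNLocator           刻度数量最大值 (finds up to a maximum number of ticks at nice locations)
AutoLocator           默认的刻度数量最大值 ((default) MaxNLocator with simple defaults)
AutoMinorLocator      默认的次要刻度 (locator for minor ticks)

Formatter 对象        描述 (Description)
NullFormatter         无标签 (no labels on the ticks)
IndexFormatter        从一个列表获得标签 (labels taken from a list)
FixedFormatter        从固定的字符串设置标签 (labels set manually from fixed strings)
FuncFormatter         使用自定义函数设置标签 (a user-defined function sets the labels)
FormatStrFormatter    使用一个格式化字符串设置标签 (a format string is applied to each value)
ScalarFormatter       默认的标量标签 ((default) formatter for scalar values)
LogFormatter          默认的对数标签 (default formatter for log axes)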
We'll see further examples of these through the remainder of the book.
在本书后续章节我们会看到更多的例⼦。
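Since locators and formatters are plain callables, several entries of the table can be tried without a plot. A sketch using matplotlib.ticker (where these classes are defined); the labels and format string are made up for illustration:

```python
import numpy as np
from matplotlib import ticker

fixed = ticker.FixedFormatter(['low', 'mid', 'high'])   # labels chosen by tick position
fmt_str = ticker.FormatStrFormatter('%.1f%%')           # old-style format string
linear = ticker.LinearLocator(numticks=5)               # evenly spaced ticks

second_label = fixed(0.0, pos=1)      # the value is ignored; pos selects the label
percent_label = fmt_str(12.345)
even_ticks = linear.tick_values(0, 1)
print(second_label, percent_label, even_ticks)
```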
Customizing Matplotlib: Configurations and Stylesheets
⾃定义matplotlib:配置和样式单
Matplotlib's default plot settings are often the subject of complaint among its users. While much is slated to change in the
2.0 Matplotlib release in late 2016, the ability to customize default settings helps bring the package in line with your own
aesthetic preferences.
Matplotlib默认的图表配置在用户当中经常被吐槽。虽然其中很多内容都预计会在2016年底发布的Matplotlib 2.0版本中进行更改,
但是自定义默认配置的能力可以让这个包符合你自己的审美偏好,因此还是有必要掌握的。
Here we'll walk through some of Matplotlib's runtime configuration (rc) options, and take a look at the newer stylesheets
feature, which contains some nice sets of default configurations.
本节我们首先介绍Matplotlib的一些运行时配置(rc)选项,然后再看一看较新的样式单特性,它包含了一些很好的默认配置集合。
Plot Customization by Hand
⼿动图表配置
Through this chapter, we've seen how it is possible to tweak individual plot settings to end up with something that looks a
little bit nicer than the default. It's possible to do these customizations for each individual plot. For example, here is a
fairly drab default histogram:
在本章中,我们已经看到可以调整单个图表的配置,让最终的效果比默认样式要好看一些。这些自定义可以对每个图表分别进行。例如,
下面是一个相当单调的默认直方图:
In [1]: import matplotlib.pyplot as plt
plt.style.use('classic')
import numpy as np
%matplotlib inline
In [2]: x = np.random.randn(1000)
plt.hist(x);
We can adjust this by hand to make it a much more visually pleasing plot:
我们可以⼿动调整配置让上图看起来更加吸引⼈:
Translator's note (译者注): In newer versions of Matplotlib, axes no longer have the axisbg property, so the code below uses the
facecolor property instead.
In [3]: # use a gray background
ax = plt.axes(facecolor='#E6E6E6')
ax.set_axisbelow(True)

# draw solid white grid lines
plt.grid(color='w', linestyle='solid')

# hide axis spines
for spine in ax.spines.values():
    spine.set_visible(False)

# hide top and right ticks
ax.xaxis.tick_bottom()
ax.yaxis.tick_left()

# lighten ticks and labels
ax.tick_params(colors='gray', direction='out')
for tick in ax.get_xticklabels():
    tick.set_color('gray')
for tick in ax.get_yticklabels():
    tick.set_color('gray')

# control face and edge color of histogram
ax.hist(x, edgecolor='#E6E6E6', color='#EE6666');
This looks better, and you may recognize it as inspired by the look of the R language's ggplot visualization
package. But this took a whole lot of effort! We definitely do not want to have to do all that tweaking each time we create
a plot. Fortunately, there is a way to adjust these defaults once in a way that will work for all plots.
Changing the Defaults: rcParams
Each time Matplotlib loads, it defines a runtime configuration (rc) containing the default styles for every plot element you
create. This configuration can be adjusted at any time using the plt.rc convenience routine. Let's see what it looks
like to modify the rc parameters so that our default plot will look similar to what we did before.
We'll start by saving a copy of the current rcParams dictionary, so we can easily reset these changes in the current
session:
In [4]: IPython_default = plt.rcParams.copy()
Now we can use the plt.rc function to change some of these settings:
In [5]: from matplotlib import cycler
colors = cycler('color',
['#EE6666', '#3388BB', '#9988DD',
'#EECC55', '#88BB44', '#FFBBBB'])
plt.rc('axes', facecolor='#E6E6E6', edgecolor='none',
axisbelow=True, grid=True, prop_cycle=colors)
plt.rc('grid', color='w', linestyle='solid')
plt.rc('xtick', direction='out', color='gray')
plt.rc('ytick', direction='out', color='gray')
plt.rc('patch', edgecolor='#E6E6E6')
plt.rc('lines', linewidth=2)
With these settings defined, we can now create a plot and see our settings in action:
In [6]: plt.hist(x);
Let's see what simple line plots look like with these rc parameters:
In [7]: for i in range(4):
    plt.plot(np.random.rand(10))
I find this much more aesthetically pleasing than the default styling. If you disagree with my aesthetic sense, the good
news is that you can adjust the rc parameters to suit your own tastes! These settings can be saved in a .matplotlibrc file,
which you can read about in the Matplotlib documentation. That said, I prefer to customize Matplotlib using its
stylesheets instead.
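Two details of the rc mechanism used above are worth making concrete. First, plt.rc(group, key=value) is a shorthand that writes the dotted entry "group.key" into plt.rcParams. Second, plt.rc_context applies a dictionary of such entries only inside a with-block and restores the previous values afterwards. The settings below are borrowed from this section. The sketch uses the non-interactive Agg backend so it can run outside a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs outside a notebook
import matplotlib.pyplot as plt
from matplotlib import cycler
from matplotlib.colors import to_rgba

# plt.rc('lines', linewidth=2) writes the dotted key 'lines.linewidth';
# the same dotted names can be passed to rc_context as a dictionary.
custom = {
    "axes.facecolor": "#E6E6E6",
    "axes.grid": True,
    "grid.color": "w",
    "lines.linewidth": 2,
    "axes.prop_cycle": cycler("color", ["#EE6666", "#3388BB", "#9988DD"]),
}

linewidth_before = plt.rcParams["lines.linewidth"]

with plt.rc_context(custom):
    fig, ax = plt.subplots()
    (line,) = ax.plot([0, 1, 2], [0, 1, 4])
    # These values come from the temporary rc settings
    inside_linewidth = line.get_linewidth()
    inside_color = to_rgba(line.get_color())
    inside_facecolor = ax.get_facecolor()

# Leaving the with-block restores the previous settings
linewidth_after = plt.rcParams["lines.linewidth"]
print(inside_linewidth, linewidth_before, linewidth_after)
```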
Stylesheets
The version 1.4 release of Matplotlib in August 2014 added a very convenient style module, which includes a number
of new default stylesheets, as well as the ability to create and package your own styles. These stylesheets are formatted
similarly to the .matplotlibrc files mentioned earlier, but must be named with a .mplstyle extension.
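To make the file format concrete, the sketch below writes a tiny stylesheet into a temporary directory and applies it with plt.style.context. That function accepts a path to a style file as well as a built-in style name. The file name my_style.mplstyle and the chosen settings are arbitrary. Hex colors are written without the leading #, as Matplotlib's bundled stylesheets do, because an unquoted # starts a comment in these files.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba

# A minimal stylesheet: the same "key: value" lines a matplotlibrc file uses
STYLE_TEXT = """
axes.facecolor: E6E6E6
axes.grid: True
grid.color: white
lines.linewidth: 3
"""

before_linewidth = plt.rcParams["lines.linewidth"]

with tempfile.TemporaryDirectory() as tmpdir:
    style_path = os.path.join(tmpdir, "my_style.mplstyle")  # arbitrary file name
    with open(style_path, "w") as fh:
        fh.write(STYLE_TEXT)

    # Apply the file temporarily; previous rcParams return when the block exits
    with plt.style.context(style_path):
        inside_linewidth = plt.rcParams["lines.linewidth"]
        inside_grid = plt.rcParams["axes.grid"]
        inside_facecolor = to_rgba(plt.rcParams["axes.facecolor"])

after_linewidth = plt.rcParams["lines.linewidth"]
```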
Even if you don't create your own style, the stylesheets included by default are extremely useful. The available styles are
listed in plt.style.available; here I'll list only the first five for brevity:
In [8]: plt.style.available[:5]
Out[8]: ['seaborn-darkgrid',
'seaborn-dark-palette',
'_classic_test',
'seaborn-paper',
'grayscale']
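The exact list depends on your Matplotlib version. As far as I know, Matplotlib 3.6 renamed the seaborn-* styles to seaborn-v0_8-*, so code that uses the names shown above may fail on a newer installation. A defensive pattern is to pick the first name that is actually present. The helper first_available below is a hypothetical convenience, not part of Matplotlib.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


def first_available(*candidates, fallback="default"):
    """Return the first candidate style name this Matplotlib provides.

    Hypothetical helper for this sketch. 'default' is not listed in
    plt.style.available, but plt.style.use/context accept it.
    """
    available = set(plt.style.available)
    for name in candidates:
        if name in available:
            return name
    return fallback


# Older and newer spellings of the same seaborn-derived style
style = first_available("seaborn-darkgrid", "seaborn-v0_8-darkgrid")
print(style)

with plt.style.context(style):
    fig, ax = plt.subplots()
    ax.plot([1, 3, 2])
```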
The basic way to switch to a stylesheet is to call
plt.style.use('stylename')
But keep in mind that this will change the style for the rest of the session! Alternatively, you can use the style context
manager, which sets a style temporarily:
with plt.style.context('stylename'):
    make_a_plot()  # placeholder for your own plotting code
Let's create a function that will make two basic types of plot:
In [9]: def hist_and_lines():
    np.random.seed(0)
    fig, ax = plt.subplots(1, 2, figsize=(11, 4))
    ax[0].hist(np.random.randn(1000))
    for i in range(3):
        ax[1].plot(np.random.rand(10))
    ax[1].legend(['a', 'b', 'c'], loc='lower left')
We'll use this to explore how these plots look using the various built-in styles.
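Outside a notebook, one way to compare several styles is to loop over style names, draw the same figure under each with plt.style.context, and save the results. This sketch repeats hist_and_lines so it is self-contained, with one change made for this example: the function returns the figure so the loop can save and close it. The PNGs go into in-memory buffers, and the style list is first filtered against plt.style.available in case a name is missing from your Matplotlib version.

```python
import io

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


def hist_and_lines():
    """Same two-panel figure as above: a histogram and three random lines."""
    np.random.seed(0)
    fig, ax = plt.subplots(1, 2, figsize=(11, 4))
    ax[0].hist(np.random.randn(1000))
    for i in range(3):
        ax[1].plot(np.random.rand(10))
    ax[1].legend(['a', 'b', 'c'], loc='lower left')
    return fig  # returned here so the caller can save and close it


wanted = ["fivethirtyeight", "ggplot", "bmh", "dark_background", "grayscale"]
styles = [s for s in wanted if s in plt.style.available]

rendered = {}
for name in styles:
    with plt.style.context(name):
        fig = hist_and_lines()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
        plt.close(fig)  # free the figure once it has been written
    rendered[name] = buf.getvalue()

print({name: len(png) for name, png in rendered.items()})
```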
Default style
The default style is what we've been seeing so far throughout the book; we'll start with that. First, let's reset our runtime configuration to
the notebook default:
In [10]: # reset rcParams
plt.rcParams.update(IPython_default);
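If you want Matplotlib's built-in defaults rather than the exact state saved in IPython_default, plt.rcdefaults() (or plt.style.use('default')) is an alternative that needs no saved copy. Note that it targets something different: it ignores anything the notebook backend changed at start-up. Writing an old copy back can also trigger deprecation warnings in some Matplotlib versions, when the copy contains keys that have since been deprecated. A small sketch:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

plt.rc("lines", linewidth=5)          # change a setting
changed = plt.rcParams["lines.linewidth"]

plt.rcdefaults()                      # back to Matplotlib's built-in defaults
restored = plt.rcParams["lines.linewidth"]
default = matplotlib.rcParamsDefault["lines.linewidth"]
print(changed, restored, default)
```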
Now let's see how it looks:
In [11]: hist_and_lines()
FiveThirtyEight style
The fivethirtyeight style mimics the graphics found on the popular FiveThirtyEight website. As you can see here,
it is typified by bold colors, thick lines, and transparent axes:
In [12]: with plt.style.context('fivethirtyeight'):
    hist_and_lines()
ggplot
The ggplot package in the R language is a very popular visualization tool. Matplotlib's ggplot style mimics the
default styles from that package:
In [13]: with plt.style.context('ggplot'):
    hist_and_lines()
Bayesian Methods for Hackers style
There is a very nice short online book called Probabilistic Programming and Bayesian Methods for Hackers; it features
figures created with Matplotlib, and uses a nice set of rc parameters to create a consistent and visually-appealing style
throughout the book. This style is reproduced in the bmh stylesheet:
In [14]: with plt.style.context('bmh'):
    hist_and_lines()
Dark background
For figures used within presentations, it is often useful to have a dark rather than light background. The
dark_background style provides this:
In [15]: with plt.style.context('dark_background'):
    hist_and_lines()
Grayscale
Sometimes you might find yourself preparing figures for a print publication that does not accept color figures. For this, the
grayscale style, shown here, can be very useful:
In [16]: with plt.style.context('grayscale'):
    hist_and_lines()
Seaborn style
Matplotlib also has stylesheets inspired by the Seaborn library (discussed more fully in Visualization With Seaborn). As
we will see, Seaborn applies settings like these when its set() function is called (older Seaborn releases applied them
automatically on import). I've found these settings to be very nice, and tend to use them as defaults in my own data exploration.
In [17]: import seaborn
seaborn.set()
hist_and_lines()
With all of these built-in options for various plot styles, Matplotlib becomes much more useful for both interactive
visualization and creation of figures for publication. Throughout this book, I will generally use one or more of these style
conventions when creating plots.
Three-Dimensional Plotting in Matplotlib
Matplotlib was initially designed with only two-dimensional plotting in mind. Around the time of the 1.0 release, some
three-dimensional plotting utilities were built on top of Matplotlib's two-dimensional display, and the result is a convenient
(if somewhat limited) set of tools for three-dimensional data visualization. Three-dimensional plots are enabled by
importing the mplot3d toolkit, included with the main Matplotlib installation:
In [1]: from mpl_toolkits import mplot3d
Once this submodule is imported, a three-dimensional axes can be created by passing the keyword
projection='3d' to any of the normal axes creation routines:
In [7]: %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
In [8]: fig = plt.figure()
ax = plt.axes(projection='3d')
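Two details about 3D axes, stated as I understand recent Matplotlib behavior. Since Matplotlib 3.2, the explicit mpl_toolkits import is no longer required for projection='3d' to work, although keeping it does no harm. The same keyword also works with the object-oriented fig.add_subplot. The resulting axes reports its projection through its name attribute:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # explicit import; optional on Matplotlib >= 3.2

fig = plt.figure()
# Object-oriented equivalent of plt.axes(projection='3d')
ax = fig.add_subplot(1, 1, 1, projection="3d")
print(type(ax).__name__, ax.name)
```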
With this three-dimensional axes enabled, we can now plot a variety of three-dimensional plot types. Three-dimensional
plotting is one of the functionalities that benefits immensely from viewing figures interactively rather than statically in the
notebook; recall that to use interactive figures, you can use %matplotlib notebook rather than %matplotlib
inline when running this code.
Three-dimensional Points and Lines
The most basic three-dimensional plot is a line or a collection of scatter points created from sets of (x, y, z) triples. In analogy
with the more common two-dimensional plots discussed earlier, these can be created using the ax.plot3D and
ax.scatter3D functions. The call signature for these is nearly identical to that of their two-dimensional counterparts,
so you can refer to Simple Line Plots and Simple Scatter Plots for more information on controlling the output. Here we'll
plot a trigonometric spiral, along with some points drawn randomly near the line:
In [9]: ax = plt.axes(projection='3d')
# Data for a three-dimensional line
zline = np.linspace(0, 15, 1000)
xline = np.sin(zline)
yline = np.cos(zline)
ax.plot3D(xline, yline, zline, 'gray')
# Data for three-dimensional scattered points
zdata = 15 * np.random.random(100)
xdata = np.sin(zdata) + 0.1 * np.random.randn(100)
ydata = np.cos(zdata) + 0.1 * np.random.randn(100)
ax.scatter3D(xdata, ydata, zdata, c=zdata, cmap='Greens');
Notice that by default, the scatter points have their transparency adjusted to give a sense of depth on the page. While the
three-dimensional effect is sometimes difficult to see within a static image, an interactive view can lead to some nice
intuition about the layout of the points.
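The depth cue comes from the depthshade argument of Axes3D.scatter, which scatter3D also accepts, and it is on by default. If the fading gets in the way, for example when the colors themselves encode a value, you can switch it off. This sketch draws spiral-like data of the same kind side by side with and without shading. The random generator and seed are arbitrary choices for the example.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
zdata = 15 * rng.random(100)
xdata = np.sin(zdata) + 0.1 * rng.standard_normal(100)
ydata = np.cos(zdata) + 0.1 * rng.standard_normal(100)

fig = plt.figure(figsize=(10, 4))
for i, shade in enumerate([True, False], start=1):
    ax = fig.add_subplot(1, 2, i, projection="3d")
    # depthshade=True fades points that are farther from the viewer
    ax.scatter3D(xdata, ydata, zdata, c=zdata, cmap="Greens", depthshade=shade)
    ax.set_title(f"depthshade={shade}")
```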
Three-dimensional Contour Plots
Analogous to the contour plots we explored in Density and Contour Plots, mplot3d contains tools to create three-dimensional relief plots using the same inputs. Like two-dimensional ax.contour plots, ax.contour3D requires all
the input data to be in the form of two-dimensional regular grids, with the Z data evaluated at each point. Here we'll show
a three-dimensional contour diagram of a three-dimensional sinusoidal function:
In [10]: def f(x, y):
    return np.sin(np.sqrt(x ** 2 + y ** 2))
x = np.linspace(-6, 6, 30)
y = np.linspace(-6, 6, 30)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)
In [11]: fig = plt.figure()
ax = plt.axes(projection='3d')
ax.contour3D(X, Y, Z, 50, cmap='binary')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z');
Sometimes the default viewing angle is not optimal, in which case we can use the view_init method to set the
elevation and azimuthal angles. In the following example, we'll use an elevation of 60 degrees (that is, 60 degrees above
the x-y plane) and an azimuth of 35 degrees (that is, rotated 35 degrees counter-clockwise about the z-axis):
In [12]: ax.view_init(60, 35)
fig
Again, note that this type of rotation can be accomplished interactively by clicking and dragging when using one of
Matplotlib's interactive backends.
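Whether the view was set with view_init or by dragging, the current angles are stored on the axes as the elev and azim attributes. This is handy for reproducing a view you found interactively. A minimal check with the same angles as above:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-6, 6, 30)
X, Y = np.meshgrid(x, x)
Z = np.sin(np.sqrt(X ** 2 + Y ** 2))

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection="3d")
ax.contour3D(X, Y, Z, 50, cmap="binary")

ax.view_init(60, 35)     # elevation, then azimuth, in degrees
print(ax.elev, ax.azim)  # read back the current viewing angles
```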
Wireframes and Surface Plots
Two other types of three-dimensional plots that work on gridded data are wireframes and surface plots. These take a grid
of values and project it onto the specified three-dimensional surface, and can make the resulting three-dimensional forms
quite easy to visualize. Here's an example of using a wireframe:
In [13]: fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_wireframe(X, Y, Z, color='black')
ax.set_title('wireframe');
A surface plot is like a wireframe plot, but each face of the wireframe is a filled polygon. Adding a colormap to the filled
polygons can aid perception of the topology of the surface being visualized:
In [14]: ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
cmap='viridis', edgecolor='none')
ax.set_title('surface');
Note that though the grid of values for a surface plot needs to be two-dimensional, it need not be rectilinear. Here is an
example of creating a partial polar grid, which when used with plot_surface can give us a slice into the function
we're visualizing:
In [15]: r = np.linspace(0, 6, 20)
theta = np.linspace(-0.9 * np.pi, 0.8 * np.pi, 40)
r, theta = np.meshgrid(r, theta)
X = r * np.sin(theta)
Y = r * np.cos(theta)
Z = f(X, Y)
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
cmap='viridis', edgecolor='none');
Surface Triangulations
For some applications, the evenly sampled grids required by the above routines are overly restrictive and inconvenient. In
these situations, the triangulation-based plots can be very useful. What if rather than an even draw from a Cartesian or a
polar grid, we instead have a set of random draws?
In [16]: theta = 2 * np.pi * np.random.random(1000)
r = 6 * np.random.random(1000)
x = np.ravel(r * np.sin(theta))
y = np.ravel(r * np.cos(theta))
z = f(x, y)
We could create a scatter plot of the points to get an idea of the surface we're sampling from:
In [17]: ax = plt.axes(projection='3d')
ax.scatter(x, y, z, c=z, cmap='viridis', linewidth=0.5);
This leaves a lot to be desired. The function that will help us in this case is ax.plot_trisurf , which creates a
surface by first finding a set of triangles formed between adjacent points (remember that x, y, and z here are one-dimensional arrays):
In [18]: ax = plt.axes(projection='3d')
ax.plot_trisurf(x, y, z,
cmap='viridis', edgecolor='none');
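To see what plot_trisurf works with, you can build the triangulation yourself with matplotlib.tri.Triangulation. Given only x and y, it computes a Delaunay triangulation of the points. Its triangles attribute is an integer array of shape (ntri, 3) whose rows index into x and y. This sketch regenerates random points like those above with an arbitrary seed, and then passes the precomputed triangulation to plot_trisurf.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.tri import Triangulation

rng = np.random.default_rng(42)
theta = 2 * np.pi * rng.random(1000)
r = 6 * rng.random(1000)
x = r * np.sin(theta)
y = r * np.cos(theta)
z = np.sin(np.sqrt(x ** 2 + y ** 2))

# Delaunay triangulation of the (x, y) positions only; z plays no part in it
tri = Triangulation(x, y)
print(tri.triangles.shape)  # (number of triangles, 3)

# plot_trisurf also accepts a precomputed triangulation followed by z
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection="3d")
surf = ax.plot_trisurf(tri, z, cmap="viridis", edgecolor="none")
```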
The result is certainly not as clean as when it is plotted with a grid, but the flexibility of such a triangulation allows for
some really interesting three-dimensional plots. For example, it is actually possible to plot a three-dimensional Möbius
strip using this, as we'll see next.
Example: Visualizing a Möbius strip
A Möbius strip is similar to a strip of paper glued into a loop with a half-twist. Topologically, it's quite interesting because
despite appearances it has only a single side! Here we will visualize such an object using Matplotlib's three-dimensional
tools. The key to creating the Möbius strip is to think about its parametrization: it's a two-dimensional strip, so we need
two intrinsic dimensions. Let's call them θ, which ranges from 0 to 2π around the loop, and w, which runs across the
width of the strip (the code below uses values from -0.25 to 0.25):
In [19]: theta = np.linspace(0, 2 * np.pi, 30)
w = np.linspace(-0.25, 0.25, 8)
w, theta = np.meshgrid(w, theta)
Now from this parametrization, we must determine the (x, y, z) positions of the embedded strip.
Thinking about it, we might realize that there are two rotations happening: one is the position of the loop about its center
(what we've called θ), while the other is the twisting of the strip about its axis (we'll call this ϕ). For a Möbius strip,
the strip must make half a twist during one full loop, so Δϕ = Δθ/2.
In [20]: phi = 0.5 * theta
Now we use our recollection of trigonometry to derive the three-dimensional embedding. We'll define r , the distance of
each point from the center, and use this to find the embedded (x, y, z) coordinates:
In [21]: # r is the distance of each point from the center of the loop
r = 1 + w * np.cos(phi)
# use basic trigonometry to compute the x, y, z coordinates
x = np.ravel(r * np.cos(theta))
y = np.ravel(r * np.sin(theta))
z = np.ravel(w * np.sin(phi))
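One way to check that ϕ = θ/2 really produces a single half twist is to compare the two ends of the parametrization numerically. The cross-section at θ = 2π should coincide with the one at θ = 0, but with w reversed (w → −w). The sketch recomputes the coordinates on its own small grid. The helper name mobius_xyz is illustrative.

```python
import numpy as np


def mobius_xyz(theta, w):
    """Embed the (theta, w) parametrization of a Mobius strip in 3D (illustrative helper)."""
    phi = 0.5 * theta                  # half a twist per full loop
    r = 1 + w * np.cos(phi)            # distance from the loop's center
    return r * np.cos(theta), r * np.sin(theta), w * np.sin(phi)


w = np.linspace(-0.25, 0.25, 8)        # symmetric across the strip
start = np.array(mobius_xyz(np.zeros_like(w), w))          # theta = 0
end = np.array(mobius_xyz(np.full_like(w, 2 * np.pi), w))  # theta = 2*pi

# After one loop each point lands where the opposite edge started (w -> -w)
print(np.allclose(end, start[:, ::-1]))
```

Because the two edges swap after one loop, they join into a single edge, which is why the strip has only one side.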
Finally, to plot the object, we must make sure the triangulation is correct. The best way to do this is to define the
triangulation within the underlying parametrization, and then let Matplotlib project this triangulation into the three-dimensional space of the Möbius strip. This can be accomplished as follows:
In [22]: # triangulate in the underlying parametrization
from matplotlib.tri import Triangulation
tri = Triangulation(np.ravel(w), np.ravel(theta))
ax = plt.axes(projection='3d')
ax.plot_trisurf(x, y, z, triangles=tri.triangles,
cmap='viridis', linewidths=0.2);
ax.set_xlim(-1, 1); ax.set_ylim(-1, 1); ax.set_zlim(-1, 1);
Combining all of these techniques, it is possible to create and display a wide variety of three-dimensional objects and
patterns in Matplotlib.
Geographic Data with Basemap
One common type of visualization in data science is that of geographic data. Matplotlib's main tool for this type of
visualization is the Basemap toolkit, which is one of several Matplotlib toolkits that live under the mpl_toolkits
namespace. Admittedly, Basemap feels a bit clunky to use, and often even simple visualizations take much longer to
render than you might hope. More modern solutions such as leaflet or the Google Maps API may be a better choice for
more intensive map visualizations. Still, Basemap is a useful tool for Python users to have in their virtual toolbelts. In this
section, we'll show several examples of the type of map visualization that is possible with this toolkit.
Installation of Basemap is straightforward; if you're using conda you can type this and the package will be downloaded:
$ conda install basemap
We add just a single new import to our standard boilerplate:
Translator's note: after Basemap is installed, the PROJ_LIB environment variable needed by its proj4 dependency is not set automatically. You can set it in your $HOME/.bashrc, or set it to <anaconda-home>/share/proj from Python as shown below.
In [1]: %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
# change this path if your Anaconda installation is in a different directory
os.environ['PROJ_LIB'] = os.environ['HOME'] + '/anaconda3/share/proj/'
from mpl_toolkits.basemap import Basemap
Once you have the Basemap toolkit installed and imported, geographic plots are just a few lines away (the graphics in the
following also require the PIL package in Python 2, or the pillow package in Python 3):
In [2]: plt.figure(figsize=(8, 8))
m = Basemap(projection='ortho', resolution=None, lat_0=50, lon_0=-100)
m.bluemarble(scale=0.5);
The meaning of the arguments to Basemap will be discussed momentarily.
The useful thing is that the globe shown here is not a mere image; it is a fully-functioning Matplotlib axes that
understands spherical coordinates and which allows us to easily overplot data on the map! For example, we can use a
different map projection, zoom in to North America and plot the location of Seattle. We'll use an etopo image (which
shows topographical features both on land and under the ocean) as the map background:
In [3]: fig = plt.figure(figsize=(8, 8))
m = Basemap(projection='lcc', resolution=None,
width=8E6, height=8E6,
lat_0=45, lon_0=-100,)
m.etopo(scale=0.5, alpha=0.5)
# map (long, lat) to (x, y) for plotting
x, y = m(-122.3, 47.6)
plt.plot(x, y, 'ok', markersize=5)
plt.text(x, y, ' Seattle', fontsize=12);
This gives you a brief glimpse into the sort of geographic visualizations that are possible with just a few lines of Python.
We'll now discuss the features of Basemap in more depth, and provide several examples of visualizing map data. Using
these brief examples as building blocks, you should be able to create nearly any map visualization that you desire.
Map Projections
The first thing to decide when using maps is what projection to use. You're probably familiar with the fact that it is
impossible to project a spherical map, such as that of the Earth, onto a flat surface without somehow distorting it or
breaking its continuity. These projections have been developed over the course of human history, and there are a lot of
choices! Depending on the intended use of the map projection, there are certain map features (e.g., direction, area,
distance, shape, or other considerations) that are useful to maintain.
The Basemap package implements several dozen such projections, all referenced by a short format code. Here we'll
briefly demonstrate some of the more common ones.
We'll start by defining a convenience routine to draw our world map along with the longitude and latitude lines:
In [4]: from itertools import chain
def draw_map(m, scale=0.2):
    # draw a shaded-relief image
    m.shadedrelief(scale=scale)
    # lats and lons are returned as dictionaries
    lats = m.drawparallels(np.linspace(-90, 90, 13))
    lons = m.drawmeridians(np.linspace(-180, 180, 13))
    # each dictionary value is a (lines, labels) tuple; collect the Line2D objects
    lat_lines = chain(*(tup[1][0] for tup in lats.items()))
    lon_lines = chain(*(tup[1][0] for tup in lons.items()))
    all_lines = chain(lat_lines, lon_lines)
    # cycle through these lines and set the desired style
    for line in all_lines:
        line.set(linestyle='-', alpha=0.3, color='w')
Cylindrical projections
The simplest of map projections are cylindrical projections, in which lines of constant latitude and longitude are mapped
to horizontal and vertical lines, respectively. This type of mapping represents equatorial regions quite well, but results in
extreme distortions near the poles. The spacing of latitude lines varies between different cylindrical projections, leading to
different conservation properties, and different distortion near the poles. In the following figure we show an example of
the equidistant cylindrical projection, which chooses a latitude scaling that preserves distances along meridians. Other
cylindrical projections are the Mercator ( projection='merc' ) and the cylindrical equal area ( projection='cea' )
projections.
In [5]: fig = plt.figure(figsize=(8, 6), edgecolor='w')
m = Basemap(projection='cyl', resolution=None,
llcrnrlat=-90, urcrnrlat=90,
llcrnrlon=-180, urcrnrlon=180, )
draw_map(m)
The additional arguments to Basemap for this view specify the latitude ( lat ) and longitude ( lon ) of the lower-left
corner ( llcrnr ) and upper-right corner ( urcrnr ) for the desired map, in units of degrees.
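What separates one cylindrical projection from another is how the vertical coordinate y depends on latitude φ. The horizontal coordinate is simply the longitude λ. For a unit sphere with angles in radians, the standard spherical forms are:

- equidistant cylindrical: y = φ
- Mercator: y = ln tan(π/4 + φ/2)
- Lambert cylindrical equal-area: y = sin φ

This NumPy sketch, which needs no Basemap, evaluates the three forms to show how the spacing between parallels changes with latitude. Basemap's own map coordinates are scaled differently, so only the relative spacing is meaningful here.

```python
import numpy as np

lat_deg = np.arange(0, 90, 15)       # 0, 15, ..., 75 degrees north
phi = np.radians(lat_deg)

y_equidistant = phi                                  # 'cyl': parallels evenly spaced
y_mercator = np.log(np.tan(np.pi / 4 + phi / 2))     # 'merc': spacing grows toward the pole
y_equal_area = np.sin(phi)                           # 'cea': spacing shrinks toward the pole

for name, y in [("cyl", y_equidistant), ("merc", y_mercator), ("cea", y_equal_area)]:
    # gap between successive 15-degree parallels
    print(name, np.round(np.diff(y), 3))
```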
Pseudo-cylindrical projections
Pseudo-cylindrical projections relax the requirement that meridians (lines of constant longitude) remain vertical; this can
give better properties near the poles of the projection. The Mollweide projection ( projection='moll' ) is one
common example of this, in which all meridians are elliptical arcs. It is constructed so as to preserve area across the
map: though there are distortions near the poles, the area of small patches reflects the true area. Other pseudo-cylindrical projections are the sinusoidal ( projection='sinu' ) and Robinson ( projection='robin' )
projections.
In [6]: fig = plt.figure(figsize=(8, 6), edgecolor='w')
m = Basemap(projection='moll', resolution=None,
lat_0=0, lon_0=0)
draw_map(m)
The extra arguments to Basemap here refer to the central latitude ( lat_0 ) and longitude ( lon_0 ) for the desired
map.
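To see where the elliptical meridians come from, here is a sketch of the spherical Mollweide equations for a unit sphere. An auxiliary angle θ solves 2θ + sin 2θ = π sin φ. Then x = (2√2/π) λ cos θ and y = √2 sin θ. Newton's method is one common way to solve for θ. The iteration count and the guard at the poles are choices made for this example, not values taken from Basemap.

```python
import numpy as np


def mollweide(lon, lat, n_iter=60):
    """Spherical Mollweide projection (unit sphere, angles in radians)."""
    lon = np.asarray(lon, dtype=float)
    lat = np.asarray(lat, dtype=float)
    target = np.pi * np.sin(lat)
    theta = lat.copy()                      # initial guess: theta = latitude
    for _ in range(n_iter):
        f = 2 * theta + np.sin(2 * theta) - target
        fprime = 2 + 2 * np.cos(2 * theta)  # vanishes only at the poles
        safe = fprime > 1e-12
        # Newton step, skipped where the derivative is zero (already solved at the poles)
        theta = theta - np.where(safe, f / np.where(safe, fprime, 1.0), 0.0)
    x = 2 * np.sqrt(2) / np.pi * lon * np.cos(theta)
    y = np.sqrt(2) * np.sin(theta)
    return x, y, theta


lon = np.radians([0, 180, 90, -90, 0, 45])
lat = np.radians([0, 0, 45, -60, 90, 80])
x, y, theta = mollweide(lon, lat)
print(np.round(x, 3), np.round(y, 3))
```

The whole map fits in an ellipse that is twice as wide as it is tall: x reaches ±2√2 at λ = ±π on the equator, and y reaches ±√2 at the poles.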
Perspective projections
Perspective projections are constructed using a particular choice of perspective point, similar to if you photographed the
Earth from a particular point in space (a point which, for some projections, technically lies within the Earth!). One common
example is the orthographic projection ( projection='ortho' ), which shows one side of the globe as seen from a
viewer at a very long distance. As such, it can show only half the globe at a time. Other perspective-based projections
include the gnomonic projection ( projection='gnom' ) and stereographic projection ( projection='stere' ).
These are often the most useful for showing small portions of the map.
Here is an example of the orthographic projection:
In [7]: fig = plt.figure(figsize=(8, 8))
m = Basemap(projection='ortho', resolution=None,
lat_0=50, lon_0=0)
draw_map(m);
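The orthographic projection has a compact closed form on a unit sphere. With the view centered on (λ0, φ0), a point (λ, φ) maps to x = cos φ sin(λ − λ0) and y = cos φ0 sin φ − sin φ0 cos φ cos(λ − λ0). The point is visible only if cos c = sin φ0 sin φ + cos φ0 cos φ cos(λ − λ0) ≥ 0, which is why only one hemisphere appears. This NumPy sketch uses the same center as the map above (lat_0=50, lon_0=0). The function name orthographic is illustrative.

```python
import numpy as np


def orthographic(lon, lat, lon_0=0.0, lat_0=50.0):
    """Unit-sphere orthographic projection; inputs in degrees. Returns x, y, visible."""
    lam, phi = np.radians(lon), np.radians(lat)
    lam0, phi0 = np.radians(lon_0), np.radians(lat_0)
    cos_c = np.sin(phi0) * np.sin(phi) + np.cos(phi0) * np.cos(phi) * np.cos(lam - lam0)
    x = np.cos(phi) * np.sin(lam - lam0)
    y = np.cos(phi0) * np.sin(phi) - np.sin(phi0) * np.cos(phi) * np.cos(lam - lam0)
    return x, y, cos_c >= 0  # points with cos_c < 0 lie on the far side


lon = np.array([0.0, 180.0, 90.0])
lat = np.array([50.0, -50.0, 0.0])
x, y, visible = orthographic(lon, lat)
print(np.round(x, 3), np.round(y, 3), visible)
```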
Conic projections
A Conic projection projects the map onto a single cone, which is then unrolled. This can lead to very good local
properties, but regions far from the focus point of the cone may become very distorted. One example of this is the
Lambert Conformal Conic projection ( projection='lcc' ), which we saw earlier in the map of North America. It
projects the map onto a cone arranged in such a way that two standard parallels (specified in Basemap by lat_1 and
lat_2 ) have well-represented distances, with scale decreasing between them and increasing outside of them. Other
useful conic projections are the equidistant conic projection ( projection='eqdc' ) and the Albers equal-area
projection ( projection='aea' ). Conic projections, like perspective projections, tend to be good choices for
representing small to medium patches of the globe.
In [8]: fig = plt.figure(figsize=(8, 8))
m = Basemap(projection='lcc', resolution=None,
lon_0=0, lat_0=50, lat_1=45, lat_2=55,
width=1.6E7, height=1.2E7)
draw_map(m)
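The spherical Lambert conformal conic formulas let us check the claim about scale: it is true on the two standard parallels, smaller between them, and larger outside them. With standard parallels φ1, φ2 and t(φ) = tan(π/4 + φ/2), the quantities are:

- cone constant: n = ln(cos φ1 / cos φ2) / ln(t(φ2) / t(φ1))
- F = cos φ1 · t(φ1)^n / n
- scale factor along a parallel: k(φ) = n F / (cos φ · t(φ)^n)

This sketch evaluates k with lat_1=45 and lat_2=55, as in the map above. It is a spherical approximation, whereas Basemap may use an ellipsoidal Earth by default.

```python
import numpy as np


def lcc_scale_factor(lat, lat_1=45.0, lat_2=55.0):
    """Scale factor k(lat) of a spherical Lambert conformal conic projection (degrees in)."""
    phi, phi1, phi2 = np.radians(lat), np.radians(lat_1), np.radians(lat_2)

    def t(p):
        return np.tan(np.pi / 4 + p / 2)

    n = np.log(np.cos(phi1) / np.cos(phi2)) / np.log(t(phi2) / t(phi1))  # cone constant
    F = np.cos(phi1) * t(phi1) ** n / n
    return n * F / (np.cos(phi) * t(phi) ** n)


lats = np.array([30.0, 45.0, 50.0, 55.0, 70.0])
k = lcc_scale_factor(lats)
print(dict(zip(lats, np.round(k, 4))))
```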
Other projections
If you're going to do much with map-based visualizations, I encourage you to read up on other available projections,
along with their properties, advantages, and disadvantages. Most likely, they are available in the Basemap package. If
you dig deep enough into this topic, you'll find an incredible subculture of geo-viz geeks who will be ready to argue
fervently in support of their favorite projection for any given application!
Drawing a Map Background
Earlier we saw the bluemarble() and shadedrelief() methods for projecting global images on the map, as well
as the drawparallels() and drawmeridians() methods for drawing lines of constant latitude and longitude. The
Basemap package contains a range of useful functions for drawing borders of physical features like continents, oceans,
lakes, and rivers, as well as political boundaries such as countries and US states and counties. The following are some of
the available drawing functions that you may wish to explore using IPython's help features:
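Physical boundaries and bodies of water
drawcoastlines() : Draw continental coastlines
drawlsmask() : Draw a land/sea mask image
drawmapboundary() : Draw the map boundary, including the fill color for oceans
drawrivers() : Draw rivers on the map
fillcontinents() : Fill the continents with a color; optionally fill lakes with another color
Political boundaries
drawcountries() : Draw country boundaries
drawstates() : Draw US state boundaries
drawcounties() : Draw US county boundaries
Map features
drawgreatcircle() : Draw a great circle between two points
drawparallels() : Draw lines of constant latitude
drawmeridians() : Draw lines of constant longitude
drawmapscale() : Draw a scale bar on the map
Whole-globe images
bluemarble() : Project NASA's Blue Marble image onto the map
shadedrelief() : Project a shaded-relief image onto the map
etopo() : Project an etopo relief image onto the map
warpimage() : Project a user-provided image onto the map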
For the boundary-based features, you must set the desired resolution when creating a Basemap image. The
resolution argument of the Basemap class sets the level of detail in boundaries, either 'c' (crude), 'l' (low),
'i' (intermediate), 'h' (high), 'f' (full), or None if no boundaries will be used. This choice is important: setting
high-resolution boundaries on a global map, for example, can be very slow.
Here's an example of drawing land/sea boundaries, and the effect of the resolution parameter. We'll create both a low- and high-resolution map of Scotland's beautiful Isle of Skye. It's located at 57.3°N, 6.2°W, and a map of 90,000 × 120,000 meters (90 km × 120 km; Basemap's width and height are given in meters) shows it well:
In [9]: fig, ax = plt.subplots(1, 2, figsize=(12, 8))

for i, res in enumerate(['l', 'h']):  # low and high resolution
    m = Basemap(projection='gnom', lat_0=57.3, lon_0=-6.2,  # gnomonic projection
                width=90000, height=120000, resolution=res,  # 90 km x 120 km (meters)
                ax=ax[i])                                   # draw into the corresponding subplot
    m.fillcontinents(color="#FFDDCC", lake_color='#DDEEFF')  # land and lake fill colors
    m.drawmapboundary(fill_color="#DDEEFF")                 # ocean fill color
    m.drawcoastlines()                                      # coastlines
    ax[i].set_title("resolution='{0}'".format(res));
Notice that the low-resolution coastlines are not suitable for this level of zoom, while the high-resolution coastlines work well. The low resolution would be fine for a global view, however, and would be much faster than loading the high-resolution border data for the entire globe! Finding the right resolution parameter for a given view may take some experimentation: the best route is to start with a fast, low-resolution plot and increase the resolution as needed.
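The 'gnom' projection used for the Isle of Skye maps is the gnomonic projection, which maps every great circle to a straight line. Below is a minimal sketch of its standard spherical forward equations, not Basemap's own code; it assumes a spherical Earth of radius 6371 km and returns kilometers relative to the map center, whereas Basemap works in meters measured from the lower-left corner of the map. The helper name `gnomonic_xy` is hypothetical.

```python
import math


def gnomonic_xy(lat, lon, lat0, lon0, radius_km=6371.0):
    """Project (lat, lon) in degrees onto a gnomonic plane centered at (lat0, lon0).

    Returns (x, y) in km relative to the center; radius_km is an assumed spherical Earth.
    """
    phi, lam = math.radians(lat), math.radians(lon)
    phi0, lam0 = math.radians(lat0), math.radians(lon0)
    # cos of the angular distance between the point and the center
    cos_c = (math.sin(phi0) * math.sin(phi)
             + math.cos(phi0) * math.cos(phi) * math.cos(lam - lam0))
    if cos_c <= 0:
        raise ValueError("point is 90 degrees or more from the center; not projectable")
    x = radius_km * math.cos(phi) * math.sin(lam - lam0) / cos_c
    y = radius_km * (math.cos(phi0) * math.sin(phi)
                     - math.sin(phi0) * math.cos(phi) * math.cos(lam - lam0)) / cos_c
    return x, y


# The Isle of Skye map is centered at 57.3°N, 6.2°W
print(gnomonic_xy(57.3, -6.2, 57.3, -6.2))  # (0.0, 0.0)
print(gnomonic_xy(57.8, -6.2, 57.3, -6.2))  # about 55.6 km north of the center
```

On the central meridian the y coordinate reduces to R·tan(Δφ), so distances stretch rapidly away from the center; the projection can show less than a hemisphere, which is why it suits small regional maps like this one.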
Plotting Data on Maps
Perhaps the most useful piece of the Basemap toolkit is the ability to over-plot a variety of data onto a map background.
For simple plotting and text, any plt function works on the map; you can use the Basemap instance to project latitude
and longitude coordinates to (x, y) coordinates for plotting with plt , as we saw earlier in the Seattle example.
In addition to this, there are many map-specific functions available as methods of the Basemap instance. These work
very similarly to their standard Matplotlib counterparts, but have an additional Boolean argument latlon , which if set to
True allows you to pass raw latitudes and longitudes to the method, rather than projected (x, y) coordinates.
Some of these map-specific methods are:
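contour() / contourf() : Draw contour lines or filled contours
imshow() : Draw an image
pcolor() / pcolormesh() : Draw a pseudocolor mesh
plot() : Draw lines and/or markers
scatter() : Draw points with markers
quiver() : Draw vectors
barbs() : Draw wind barbs
drawgreatcircle() : Draw a great circle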
We'll see some examples of a few of these as we continue. For more information on these functions, including several
example plots, see the online Basemap documentation.
Example: California Cities
Recall that in Customizing Plot Legends, we demonstrated the use of size and color in a scatter plot to convey
information about the location, size, and population of California cities. Here, we'll create this plot again, but using
Basemap to put the data in context.
We start with loading the data, as we did before:
In [10]: import pandas as pd
cities = pd.read_csv('data/california_cities.csv')
# Extract the data we're interested in
lat = cities['latd'].values
lon = cities['longd'].values
population = cities['population_total'].values
area = cities['area_total_km2'].values
Next, we set up the map projection, scatter the data, and then create a colorbar and legend:
In [11]: # 1. Draw the map background
fig = plt.figure(figsize=(8, 8))
m = Basemap(projection='lcc', resolution='h',  # Lambert conformal conic projection, high resolution
            lat_0=37.5, lon_0=-119,            # centered on California
            width=1E6, height=1.2E6)           # map size: 1,000 km x 1,200 km (units are meters)
m.shadedrelief()                               # shaded-relief background
m.drawcoastlines(color='gray')                 # coastlines
m.drawcountries(color='gray')                  # national borders
m.drawstates(color='gray')                     # state borders

# 2. Scatter city data: color shows population, size shows area
m.scatter(lon, lat, latlon=True,               # latlon=True: pass raw longitudes and latitudes
          c=np.log10(population), s=area,      # color by log10(population), size by area
          cmap='Reds', alpha=0.5)

# 3. Create the colorbar and set its limits
plt.colorbar(label=r'$\log_{10}({\rm population})$')
plt.clim(3, 7)

# Make a legend with dummy (empty) points of representative sizes
for a in [100, 300, 500]:
    plt.scatter([], [], c='k', alpha=0.5, s=a,
                label=str(a) + ' km$^2$')
plt.legend(scatterpoints=1, frameon=False,
           labelspacing=1, loc='lower left');
This shows us roughly where larger populations of people have settled in California: they are clustered near the coast in
the Los Angeles and San Francisco areas, stretched along the highways in the flat central valley, and avoiding almost
completely the mountainous regions along the borders of the state.
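In the California map, plt.clim(3, 7) pins the color scale, so each city's color depends on where log10(population) falls between 3 (1,000 people) and 7 (10 million people). The following is a minimal NumPy sketch of that linear mapping, assuming Matplotlib's default linear normalization and that values outside the limits take the end colors of the colormap (what happens when no separate "under"/"over" colors are set). `normalize_log_population` is a hypothetical helper and the populations are made up.

```python
import numpy as np


def normalize_log_population(population, vmin=3.0, vmax=7.0):
    """Map log10(population) linearly onto [0, 1], clipping at the color limits."""
    values = np.log10(np.asarray(population, dtype=float))
    return np.clip((values - vmin) / (vmax - vmin), 0.0, 1.0)


# Hypothetical city populations (not taken from the dataset)
print(normalize_log_population([500, 1_000, 100_000, 50_000_000]))
```

A town of 500 and a metropolis of 50 million end up at the two extremes of the 'Reds' colormap, and a city of 100,000 lands exactly in the middle; the logarithm is what keeps small and large cities distinguishable on one scale.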
Example: Surface Temperature Data
As an example of visualizing some more continuous geographic data, let's consider the "polar vortex" that hit the eastern
half of the United States in January of 2014. A great source for any sort of climatic data is NASA's Goddard Institute for
Space Studies. Here we'll use the GIS 250 temperature data, which we can download using shell commands (these
commands may have to be modified on Windows machines). The data used here was downloaded on 6/12/2016, and the
file size is approximately 9MB:
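Translator's note: the data file can no longer be found at the address below; it can be downloaded from the mirror site https://www.kompulsa.com/climate-data-mirror/ . This repository also provides a bzip2-compressed copy in the notebooks/data directory, which can be used after decompression. The data path used below has been changed accordingly to data/gistemp250.nc.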
In [12]: # !curl -O http://data.giss.nasa.gov/pub/gistemp/gistemp250.nc.gz
# !gunzip gistemp250.nc.gz
The data comes in NetCDF format, which can be read in Python by the netCDF4 library. You can install this library as shown here:
$ conda install netcdf4
We read the data as follows:
In [13]: from netCDF4 import Dataset
data = Dataset('data/gistemp250.nc')
The file contains many global temperature readings on a variety of dates; we need to select the index of the date we're
interested in—in this case, January 15, 2014:
In [14]: from netCDF4 import date2index
from datetime import datetime
timeindex = date2index(datetime(2014, 1, 15),
data.variables['time'])
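date2index translates a calendar date into a position along the file's time axis. The sketch below is a simplified stand-in, not netCDF4's implementation: it assumes (hypothetically) that the time variable holds integer "days since 1800-01-01" with one value per month on the 15th, and it mirrors only an exact-match lookup. `find_time_index` and the axis values are illustrative.

```python
from bisect import bisect_left
from datetime import datetime

EPOCH = datetime(1800, 1, 1)  # assumed reference date of the time units


def days_since_epoch(when):
    return (when - EPOCH).days


def find_time_index(when, time_values):
    """Return the index of an exact match for `when` in a sorted list of day offsets."""
    target = days_since_epoch(when)
    i = bisect_left(time_values, target)
    if i == len(time_values) or time_values[i] != target:
        raise ValueError(f"{when:%Y-%m-%d} is not on the time axis")
    return i


# Hypothetical monthly axis: the 15th of each month in 2013 and 2014
axis = [days_since_epoch(datetime(y, m, 15)) for y in (2013, 2014) for m in range(1, 13)]
print(find_time_index(datetime(2014, 1, 15), axis))  # 12
```

Because the axis is sorted, a binary search finds the position in logarithmic time; the real library also handles the file's calendar and units attributes, which this sketch ignores.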
Now we can load the latitude and longitude data, as well as the temperature anomaly for this index:
In [15]: lat = data.variables['lat'][:]
lon = data.variables['lon'][:]
lon, lat = np.meshgrid(lon, lat)
temp_anomaly = data.variables['tempanomaly'][timeindex]
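np.meshgrid(lon, lat) turns the two 1-D coordinate vectors into 2-D arrays with one entry per grid cell, which is the layout pcolormesh() expects alongside a 2-D data array. A small sketch with a hypothetical 2 × 3 grid (the real file's grid is much larger), assuming the data array is indexed [lat, lon]:

```python
import numpy as np

lon = np.array([-120.0, -110.0, -100.0])  # hypothetical longitudes (3 columns)
lat = np.array([40.0, 50.0])              # hypothetical latitudes (2 rows)

lon2d, lat2d = np.meshgrid(lon, lat)

# Rows follow latitude and columns follow longitude, matching data indexed [lat, lon]
print(lon2d.shape, lat2d.shape)  # (2, 3) (2, 3)
print(lon2d[1, 2], lat2d[1, 2])  # -100.0 50.0
```

Each cell (i, j) of the two grids gives the coordinates of the data value at [i, j], so the coordinates and the temperature anomalies can be passed to pcolormesh() side by side.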
Finally, we'll use the pcolormesh() method to draw a color mesh of the data. We'll look at North America, and use a shaded relief map in the background. Note that for this data we specifically chose a diverging colormap, which has a neutral color at zero and two contrasting colors at negative and positive values. We'll also lightly draw the coastlines over the colors for reference:
In [16]: fig = plt.figure(figsize=(10, 8))
m = Basemap(projection='lcc', resolution='c',  # Lambert conformal conic projection, crude boundaries
            width=8E6, height=8E6,             # map size: 8,000 km x 8,000 km (units are meters)
            lat_0=45, lon_0=-100)              # projection center
m.shadedrelief(scale=0.5)                      # shaded-relief background

m.pcolormesh(lon, lat, temp_anomaly,
             latlon=True, cmap='RdBu_r')       # draw the data as a color mesh
plt.clim(-8, 8)
m.drawcoastlines(color='lightgray')            # light coastlines for reference

plt.title('January 2014 Temperature Anomaly')
plt.colorbar(label='temperature anomaly (°C)');
The data paints a picture of the localized, extreme temperature anomalies that happened during that month. The eastern
half of the United States was much colder than normal, while the western half and Alaska were much warmer. Regions
with no recorded temperature show the map background.
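netCDF4 returns NumPy masked arrays when a variable has missing values (by default, entries equal to the variable's _FillValue or missing_value attribute are masked), and pcolormesh() leaves masked cells transparent, which is how the shaded relief shows through where no temperature was recorded. The sketch below uses a hypothetical fill value of 32767 and made-up anomalies; it also computes a symmetric color limit, which keeps zero at the neutral midpoint of a diverging colormap such as 'RdBu_r' (the code above fixes the limits at ±8 instead).

```python
import numpy as np

FILL = 32767  # hypothetical fill value marking "no data"
raw = np.array([[-9.5, -2.0, FILL],
                [1.5, FILL, 6.0]])

anomaly = np.ma.masked_equal(raw, FILL)  # mask cells that hold the fill value

print(anomaly.count())  # 4 valid cells
# Symmetric limits center the color scale on zero; masked cells are ignored
limit = np.abs(anomaly).max()
print(-limit, limit)  # -9.5 9.5
```

Without the mask, the fill value would be treated as a huge positive anomaly and would swamp the color scale.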
Visualization with Seaborn
Matplotlib has proven to be an incredibly useful and popular visualization tool, but even avid users will admit it often
leaves much to be desired. There are several valid complaints about Matplotlib that often come up:
Prior to version 2.0, Matplotlib's defaults were not exactly the best choices. They were based on MATLAB circa 1999, and this often shows.
Matplotlib's API is relatively low level. Doing sophisticated statistical visualization is possible, but often requires a lot
of boilerplate code.
Matplotlib predated Pandas by more than a decade, and thus is not designed for use with Pandas DataFrame s. In
order to visualize data from a Pandas DataFrame , you must extract each Series and often concatenate them
together into the right format. It would be nicer to have a plotting library that can intelligently use the DataFrame
labels in a plot.
An answer to these problems is Seaborn. Seaborn provides an API on top of Matplotlib that offers sane choices for plot
style and color defaults, defines simple high-level functions for common statistical plot types, and integrates with the
functionality provided by Pandas DataFrame s.
To be fair, the Matplotlib team is addressing this: it has recently added the plt.style tools discussed in Customizing Matplotlib: Configurations and Style Sheets, and is starting to handle Pandas data more seamlessly. The 2.0 release of the library introduced a new default stylesheet that improves on the earlier status quo. But for all the reasons just discussed, Seaborn remains an extremely useful add-on.
Seaborn Versus Matplotlib
Here is an example of a simple random-walk plot in Matplotlib, using its classic plot formatting and colors. We start with
the typical imports:
In [1]: import matplotlib.pyplot as plt
plt.style.use('classic')
%matplotlib inline
import numpy as np
import pandas as pd
Now we create some random walk data:
In [2]: # Create some random-walk data
rng = np.random.RandomState(0)
x = np.linspace(0, 10, 500)
y = np.cumsum(rng.randn(500, 6), 0)
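Each column of y is an independent random walk: np.cumsum(..., 0) adds up standard-normal steps down axis 0, so row k holds the sum of the first k + 1 steps. A quick check of that relationship, using the same seed as above:

```python
import numpy as np

rng = np.random.RandomState(0)
steps = rng.randn(500, 6)  # 500 normally distributed steps for each of 6 walks
y = np.cumsum(steps, 0)    # running total down the rows

# Differencing the walks recovers the individual steps
print(np.allclose(np.diff(y, axis=0), steps[1:]))  # True
print(y.shape)  # (500, 6)
```

Passing the whole (500, 6) array to plt.plot draws one line per column, which is why the legend needs six labels ('ABCDEF').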
And do a simple plot:
In [3]: # Plot the data with Matplotlib defaults
plt.plot(x, y)
plt.legend('ABCDEF', ncol=2, loc='upper left');
Although the result contains all the information we'd like it to convey, it does so in a way that is not all that aesthetically
pleasing, and even looks a bit old-fashioned in the context of 21st-century data visualization.
Now let's take a look at how it works with Seaborn. As we will see, Seaborn has many of its own high-level plotting routines, but it can also overwrite Matplotlib's default parameters and in turn get even simple Matplotlib scripts to produce vastly superior output. We can set the style by calling Seaborn's set() method (in Seaborn 0.11 and later the preferred name is set_theme(), with set() kept as an alias). By convention, Seaborn is imported as sns :
In [4]: import seaborn as sns
sns.set()
Now let's rerun the same two lines as before:
In [5]: # same plotting code as before
plt.plot(x, y)
plt.legend('ABCDEF', ncol=2, loc='upper left');
Ah, much better!
Exploring Seaborn Plots
The main idea of Seaborn is that it provides high-level commands to create a variety of plot types useful for statistical
data exploration, and even some statistical model fitting.
Let's take a look at a few of the datasets and plot types available in Seaborn. Note that all of the following could be done
using raw Matplotlib commands (this is, in fact, what Seaborn does under the hood) but the Seaborn API is much more
convenient.
Histograms, KDE, and densities
Often in statistical data visualization, all you want is to plot histograms and joint distributions of variables. We have seen
that this is relatively straightforward in Matplotlib:
Translator's note: in the code below, the normed argument has been changed to density.
In [6]: data = np.random.multivariate_normal([0, 0], [[5, 2], [2, 2]], size=2000)
data = pd.DataFrame(data, columns=['x', 'y'])
for col in 'xy':
    plt.hist(data[col], density=True, alpha=0.5)
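With density=True, plt.hist rescales the bar heights so that the total area of the bars is 1, which is what lets the two columns be compared as distributions. NumPy's np.histogram applies the same normalization; this sketch uses freshly drawn sample data (an assumption, not the DataFrame above) and 10 bins, the Matplotlib default.

```python
import numpy as np

rng = np.random.default_rng(42)
sample = rng.normal(loc=0.0, scale=np.sqrt(5), size=2000)  # variance 5, like column 'x'

heights, edges = np.histogram(sample, bins=10, density=True)
area = np.sum(heights * np.diff(edges))  # bar heights times bar widths
print(round(area, 6))  # 1.0
```

The heights are therefore in units of "probability per unit of x", not raw counts.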
Rather than a histogram, we can get a smooth estimate of the distribution using a kernel density estimation, which
Seaborn does with sns.kdeplot :
In [7]: for col in 'xy':
    sns.kdeplot(data[col], fill=True)  # `fill` replaces the older `shade` argument (Seaborn 0.11+)
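A kernel density estimate places a smooth kernel (here a Gaussian) on every observation and averages them. Below is a minimal by-hand sketch, assuming Scott's rule of thumb for the bandwidth (sample standard deviation times n^(-1/5)); Seaborn's own defaults and implementation may differ in detail, and `gaussian_kde_1d` is a hypothetical helper.

```python
import numpy as np


def gaussian_kde_1d(sample, grid):
    """Evaluate a Gaussian kernel density estimate of `sample` at each point of `grid`."""
    sample = np.asarray(sample, dtype=float)
    n = sample.size
    bandwidth = sample.std(ddof=1) * n ** (-1 / 5)     # Scott's rule of thumb
    z = (grid[:, None] - sample[None, :]) / bandwidth   # shape: (grid points, observations)
    kernels = np.exp(-0.5 * z ** 2) / np.sqrt(2 * np.pi)
    return kernels.sum(axis=1) / (n * bandwidth)


rng = np.random.default_rng(0)
sample = rng.normal(size=500)
grid = np.linspace(-6, 6, 1201)
density = gaussian_kde_1d(sample, grid)

# The estimate is a proper density: its area is very close to 1
area = density.sum() * (grid[1] - grid[0])
print(round(area, 3))
```

A larger bandwidth gives a smoother curve that hides detail, and a smaller one gives a bumpier curve that follows noise; that trade-off is the main knob behind any KDE plot.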
Histograms and KDE can be combined using distplot (deprecated since Seaborn 0.11 in favor of histplot and displot):
In [8]: sns.distplot(data['x'])
sns.distplot(data['y']);
If we pass the full two-dimensional dataset to kdeplot , we will get a two-dimensional visualization of the data:
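Translator's note: in newer versions of Seaborn, kdeplot no longer accepts a two-dimensional dataset as a single argument; the two columns must be passed separately, so the code below has been changed to the two-argument form.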
In [9]: sns.kdeplot(x=data.x, y=data.y);
We can see the joint distribution and the marginal distributions together using sns.jointplot . For this plot, we'll set
the style to a white background:
In [10]: with sns.axes_style('white'):
    sns.jointplot(x="x", y="y", data=data, kind='kde');
There are other parameters that can be passed to jointplot —for example, we can use a hexagonally based
histogram instead:
In [11]: with sns.axes_style('white'):
    sns.jointplot(x="x", y="y", data=data, kind='hex')
Pair plots
When you generalize joint plots to datasets of larger dimensions, you end up with pair plots. This is very useful for
exploring correlations between multidimensional data, when you'd like to plot all pairs of values against each other.
We'll demo this with the well-known Iris dataset, which lists measurements of petals and sepals of three iris species:
In [12]: iris = sns.load_dataset("iris")
iris.head()
Out[12]:
   sepal_length  sepal_width  petal_length  petal_width species
0           5.1          3.5           1.4          0.2  setosa
1           4.9          3.0           1.4          0.2  setosa
2           4.7          3.2           1.3          0.2  setosa
3           4.6          3.1           1.5          0.2  setosa
4           5.0          3.6           1.4          0.2  setosa
Visualizing the multidimensional relationships among the samples is as easy as calling sns.pairplot :
Translator's note: the size argument in the code below is deprecated and has been changed to height.
In [13]: sns.pairplot(iris, hue='species', height=2.5);
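A pair plot is a grid with one panel for every (row, column) pair of variables: off-diagonal panels plot one variable against another, and the diagonal shows each variable's own distribution. A short sketch of that layout for the four Iris measurements:

```python
from itertools import combinations, product

variables = ["sepal_length", "sepal_width", "petal_length", "petal_width"]

panels = list(product(variables, repeat=2))        # every (row, column) cell of the grid
diagonal = [(a, b) for a, b in panels if a == b]   # univariate distributions
distinct_pairs = list(combinations(variables, 2))  # each off-diagonal pair, counted once

print(len(panels), len(diagonal), len(distinct_pairs))  # 16 4 6
```

The 12 off-diagonal panels show each of the 6 distinct pairs twice, mirrored across the diagonal, so the number of panels grows with the square of the number of variables.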
Faceted histograms
Sometimes the best way to view data is via histograms of subsets. Seaborn's FacetGrid makes this extremely simple.
We'll take a look at some data that shows the amount that restaurant staff receive in tips based on various indicator data:
Translator's note: the code below reads tips.csv directly from the data directory, because Seaborn can no longer download the tips dataset from the web.
In [14]: import pandas as pd
tips = pd.read_csv('data/tips.csv')
tips.head()
Out[14]:
   total_bill   tip     sex smoker  day    time  size
0       16.99  1.01  Female     No  Sun  Dinner     2
1       10.34  1.66    Male     No  Sun  Dinner     3
2       21.01  3.50    Male     No  Sun  Dinner     3
3       23.68  3.31    Male     No  Sun  Dinner     2
4       24.59  3.61  Female     No  Sun  Dinner     4
In [15]: tips['tip_pct'] = 100 * tips['tip'] / tips['total_bill']
grid = sns.FacetGrid(tips, row="sex", col="time", margin_titles=True)
grid.map(plt.hist, "tip_pct", bins=np.linspace(0, 40, 15));
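The tip_pct column expresses each tip as a percentage of the bill, and FacetGrid then draws one histogram per (sex, time) combination. A small pandas sketch of the same computation on the five rows displayed above, grouped the way the grid's rows and columns are:

```python
import pandas as pd

# The first five rows of tips.csv, as displayed above
tips_head = pd.DataFrame({
    "total_bill": [16.99, 10.34, 21.01, 23.68, 24.59],
    "tip": [1.01, 1.66, 3.50, 3.31, 3.61],
    "sex": ["Female", "Male", "Male", "Male", "Female"],
    "time": ["Dinner"] * 5,
})
tips_head["tip_pct"] = 100 * tips_head["tip"] / tips_head["total_bill"]

# One group per facet: row = sex, column = time
grouped = tips_head.groupby(["sex", "time"])["tip_pct"].mean()
print(grouped.round(2))
```

FacetGrid does the same splitting, but instead of reducing each group to a mean it hands every group's tip_pct values to plt.hist.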
Factor plots
Factor plots (drawn with sns.catplot, the successor of sns.factorplot in newer versions of Seaborn) can be useful for this kind of visualization as well. They allow you to view the distribution of a parameter within bins defined by any other parameter:
Translator's note: the factorplot function is deprecated; the code below has been updated to use catplot.
In [16]: with sns.axes_style(style='ticks'):
    g = sns.catplot(x="day", y="total_bill", hue="sex", data=tips, kind="box")
    g.set_axis_labels("Day", "Total Bill");
Joint distributions
Similar to the pairplot we saw earlier, we can use sns.jointplot to show the joint distribution between different
datasets, along with the associated marginal distributions:
In [17]: with sns.axes_style('white'):
    sns.jointplot(x="total_bill", y="tip", data=tips, kind='hex')
The joint plot can even do some automatic kernel density estimation and regression:
In [18]: sns.jointplot(x="total_bill", y="tip", data=tips, kind='reg');
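kind='reg' overlays a linear regression fit on the scatter of total_bill against tip. The core of that fit is ordinary least squares; here is a sketch with np.polyfit on the five rows shown earlier. Five points are far too few for a meaningful estimate, so this only illustrates the mechanics, and Seaborn additionally draws a bootstrapped confidence band that the sketch omits.

```python
import numpy as np

total_bill = np.array([16.99, 10.34, 21.01, 23.68, 24.59])
tip = np.array([1.01, 1.66, 3.50, 3.31, 3.61])

# Least-squares line: tip ≈ slope * total_bill + intercept
slope, intercept = np.polyfit(total_bill, tip, deg=1)
predicted = slope * total_bill + intercept

print(f"tip ≈ {slope:.3f} * total_bill + {intercept:.3f}")
```

With an intercept in the model, the least-squares residuals always sum to zero, which is a handy sanity check on any hand-rolled fit.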
Bar plots
Time series can be plotted using sns.catplot (sns.factorplot in older versions of Seaborn). In the following example, we'll use the Planets data that we first saw in Aggregation and Grouping:
Translator's note: likewise, factorplot has been replaced by catplot in the code below because it is deprecated.
In [19]: planets = sns.load_dataset('planets')
planets.head()
Out[19]:
            method  number  orbital_period   mass  distance  year
0  Radial Velocity       1         269.300   7.10     77.40  2006
1  Radial Velocity       1         874.774   2.21     56.95  2008
2  Radial Velocity       1         763.000   2.60     19.84  2011
3  Radial Velocity       1         326.030  19.40    110.62  2007
4  Radial Velocity       1         516.220  10.50    119.47  2009
In [20]: with sns.axes_style('white'):
    g = sns.catplot(x="year", data=planets, aspect=2,
                    kind="count", color='steelblue')
    g.set_xticklabels(step=5)
We can learn more by looking at the method of discovery of each of these planets:
In [21]: with sns.axes_style('white'):
    g = sns.catplot(x="year", data=planets, aspect=4.0, kind='count',
                    hue='method', order=range(2001, 2015))
    g.set_ylabels('Number of Planets Discovered')
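A count plot is a bar chart of how many rows fall in each category; hue='method' splits each bar by discovery method, and order=range(2001, 2015) fixes which years appear and in what order. The same tally in pandas, sketched on the five planets rows shown above (years with no discoveries in this tiny sample get a count of 0):

```python
import pandas as pd

planets_head = pd.DataFrame({
    "method": ["Radial Velocity"] * 5,
    "year": [2006, 2008, 2011, 2007, 2009],
})

# Rows: years in the requested order; columns: discovery method (the hue variable)
counts = (pd.crosstab(planets_head["year"], planets_head["method"])
            .reindex(range(2001, 2015), fill_value=0))
print(counts)
```

Each cell of the table corresponds to the height of one bar in the hue-split count plot.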
For more information on plotting with Seaborn, see the Seaborn documentation, a tutorial, and the Seaborn gallery.
Example: Exploring Marathon Finishing Times
Here we'll look at using Seaborn to help visualize and understand finishing results from a marathon. I've scraped the data
from sources on the Web, aggregated it and removed any identifying information, and put it on GitHub where it can be
downloaded (if you are interested in using Python for web scraping, I would recommend Web Scraping with Python by
Ryan Mitchell). We will start by downloading the data from the Web, and loading it into Pandas:
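Translator's note: this repository includes the data file in the notebooks/data directory, and the path in the loading statement below has been changed accordingly.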
In [22]: # !curl -O https://raw.githubusercontent.com/jakevdp/marathon-data/master/marathon-data.csv
In [23]: data = pd.read_csv('data/marathon-data.csv')
data.head()
Out[23]:
   age gender     split     final
0   33      M  01:05:38  02:08:51
1   32      M  01:06:26  02:09:28
2   31      M  01:06:49  02:10:42
3   38      M  01:06:16  02:13:45
4   31      M  01:06:32  02:13:59
By default, Pandas loaded the time columns as Python strings (type object ); we can see this by looking at the
dtypes attribute of the DataFrame:
In [24]: data.dtypes
Out[24]:
age        int64
gender    object
split     object
final     object
dtype: object
Let's fix this by providing a converter for the times:
In [25]: import datetime

def convert_time(s):
    h, m, s = map(int, s.split(':'))
    return datetime.timedelta(hours=h, minutes=m, seconds=s)

data = pd.read_csv('data/marathon-data.csv',
                   converters={'split': convert_time, 'final': convert_time})
data.head()
Out[25]:
   age gender     split     final
0   33      M  01:05:38  02:08:51
1   32      M  01:06:26  02:09:28
2   31      M  01:06:49  02:10:42
3   38      M  01:06:16  02:13:45
4   31      M  01:06:32  02:13:59
In [26]: data.dtypes
Out[26]:
age                 int64
gender             object
split     timedelta64[ns]
final     timedelta64[ns]
dtype: object
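The hand-written converter works, but pandas can also parse "hh:mm:ss" strings directly with pd.to_timedelta. This is an alternative to the approach used above, not what the chapter does; a sketch on the first two raw rows:

```python
import datetime

import pandas as pd

# Times as they appear in the raw CSV (first rows shown above)
raw = pd.DataFrame({"split": ["01:05:38", "01:06:26"],
                    "final": ["02:08:51", "02:09:28"]})

parsed = raw.apply(pd.to_timedelta)  # parse every column as a timedelta

print(parsed.dtypes)
print(parsed.loc[0, "split"] == datetime.timedelta(hours=1, minutes=5, seconds=38))  # True
```

Applied after pd.read_csv, this gives the same timedelta64 columns without a custom converter function.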
That looks much better. For the purpose of our Seaborn plotting utilities, let's next add columns that give the times in
seconds:
In [27]: data['split_sec'] = data['split'].astype(int) / 1E9
data['final_sec'] = data['final'].astype(int) / 1E9
data.head()
Out[27]:
   age gender     split     final  split_sec  final_sec
0   33      M  01:05:38  02:08:51     3938.0     7731.0
1   32      M  01:06:26  02:09:28     3986.0     7768.0
2   31      M  01:06:49  02:10:42     4009.0     7842.0
3   38      M  01:06:16  02:13:45     3976.0     8025.0
4   31      M  01:06:32  02:13:59     3992.0     8039.0
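Dividing the integer nanosecond count by 1E9 relies on the column being stored as timedelta64[ns]. The .dt.total_seconds() accessor gives the same numbers without depending on the storage unit; a quick check on the first two rows:

```python
import pandas as pd

split = pd.to_timedelta(pd.Series(["01:05:38", "01:06:26"]))

via_int = split.astype("int64") / 1E9    # nanoseconds -> seconds, as in In [27]
via_accessor = split.dt.total_seconds()  # unit-independent alternative

print(via_accessor.tolist())  # [3938.0, 3986.0]
print((via_int == via_accessor).all())  # True
```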
To get an idea of what the data looks like, we can plot a jointplot over the data:
In [28]: with sns.axes_style('white'):
    g = sns.jointplot(x="split_sec", y="final_sec", data=data, kind='hex')
    g.ax_joint.plot(np.linspace(4000, 16000),
                    np.linspace(8000, 32000), ':k')
The dotted line shows where someone's time would lie if they ran the marathon at a perfectly steady pace. The fact that
the distribution lies above this indicates (as you might expect) that most people slow down over the course of the
marathon. If you have run competitively, you'll know that those who do the opposite—run faster during the second half of
the race—are said to have "negative-split" the race.
Let's create another column in the data, the split fraction, which measures the degree to which each runner negative-splits or positive-splits the race:
In [29]: data['split_frac'] = 1 - 2 * data['split_sec'] / data['final_sec']
data.head()
Out[29]:
   age gender     split     final  split_sec  final_sec  split_frac
0   33      M  01:05:38  02:08:51     3938.0     7731.0   -0.018756
1   32      M  01:06:26  02:09:28     3986.0     7768.0   -0.026262
2   31      M  01:06:49  02:10:42     4009.0     7842.0   -0.022443
3   38      M  01:06:16  02:13:45     3976.0     8025.0    0.009097
4   31      M  01:06:32  02:13:59     3992.0     8039.0    0.006842
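Because final = first half + second half, the definition simplifies to split_frac = (second half - first half) / final, which is why a negative value means the second half was faster. Checking both forms on the five rows shown, and counting the negative splits among them:

```python
split_sec = [3938.0, 3986.0, 4009.0, 3976.0, 3992.0]
final_sec = [7731.0, 7768.0, 7842.0, 8025.0, 8039.0]

split_frac = [1 - 2 * s / f for s, f in zip(split_sec, final_sec)]
# Equivalent form: (second half - first half) / total time
alt = [((f - s) - s) / f for s, f in zip(split_sec, final_sec)]

print(round(split_frac[0], 6))               # -0.018756
print(sum(frac < 0 for frac in split_frac))  # 3 negative splits in these rows
```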
Where this split fraction is less than zero, the person negative-split the race by that fraction. Let's do a distribution plot of this split fraction:
In [30]: sns.distplot(data['split_frac'], kde=False);
plt.axvline(0, color="k", linestyle="--");
In [31]: sum(data.split_frac < 0)
Out[31]: 251
Out of nearly 40,000 participants, only 251 people negative-split their marathon.
Let's see whether there is any correlation between this split fraction and other variables. We'll do this using a PairGrid , which draws plots of all these correlations:
In [32]: g = sns.PairGrid(data, vars=['age', 'split_sec', 'final_sec', 'split_frac'],
hue='gender', palette='RdBu_r')
g.map(plt.scatter, alpha=0.8)
g.add_legend();
It looks like the split fraction does not correlate particularly with age, but does correlate with the final time: faster runners
tend to have closer to even splits on their marathon time. (We see here that Seaborn is no panacea for Matplotlib's ills
when it comes to plot styles: in particular, the x-axis labels overlap. Because the output is a simple Matplotlib plot,
however, the methods in Customizing Ticks can be used to adjust such things if desired.)
The difference between men and women here is interesting. Let's look at kernel density estimates of the split fractions for these two groups:
In [33]: sns.kdeplot(data.split_frac[data.gender=='M'], label='men', fill=True)
sns.kdeplot(data.split_frac[data.gender=='W'], label='women', fill=True)
plt.xlabel('split_frac');
The interesting thing here is that there are many more men than women who are running close to an even split! This almost looks like some kind of bimodal distribution among the men and women. Let's see if we can suss out what's going on by looking at the distributions as a function of age.
A nice way to compare distributions is to use a violin plot:
In [34]: sns.violinplot(x="gender", y="split_frac", data=data,
               palette=["lightblue", "lightpink"]);
This is yet another way to compare the distributions between men and women.
Let's look a little deeper, and compare these violin plots as a function of age. We'll start by creating a new column in the DataFrame that specifies the decade of age that each person is in:
In [35]: data['age_dec'] = data.age.map(lambda age: 10 * (age // 10))  # decade of age: 10, 20, 30, ...
data.head()
Out[35]:
   age gender     split     final  split_sec  final_sec  split_frac  age_dec
0   33      M  01:05:38  02:08:51     3938.0     7731.0   -0.018756       30
1   32      M  01:06:26  02:09:28     3986.0     7768.0   -0.026262       30
2   31      M  01:06:49  02:10:42     4009.0     7842.0   -0.022443       30
3   38      M  01:06:16  02:13:45     3976.0     8025.0    0.009097       30
4   31      M  01:06:32  02:13:59     3992.0     8039.0    0.006842       30
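10 * (age // 10) rounds each age down to the start of its decade (33 becomes 30, 29 becomes 20, 86 becomes 80). The lambda works, but the same arithmetic can be applied to the whole column at once; a sketch on a few hypothetical ages:

```python
import pandas as pd

ages = pd.Series([33, 32, 29, 80, 86])  # hypothetical ages

by_map = ages.map(lambda age: 10 * (age // 10))  # element-wise, as in In [35]
vectorized = 10 * (ages // 10)                   # same result without a Python lambda

print(vectorized.tolist())  # [30, 30, 20, 80, 80]
print(by_map.tolist() == vectorized.tolist())  # True
```

Note that the 80 bin collects every age from 80 through 89.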
In [36]: men = (data.gender == 'M')
women = (data.gender == 'W')

with sns.axes_style(style=None):
    sns.violinplot(x="age_dec", y="split_frac", hue="gender", data=data,
                   split=True, inner="quartile",
                   palette=["lightblue", "lightpink"]);
Looking at this, we can see where the distributions of men and women differ: the split distributions of men in their 20s to
50s show a pronounced over-density toward lower splits when compared to women of the same age (or of any age, for
that matter).
Also surprisingly, the 80-year-old women seem to outperform everyone in terms of their split time. This is probably due to
the fact that we're estimating the distribution from small numbers, as there are only a handful of runners in that range:
In [37]: (data.age > 80).sum()
Out[37]: 7
Back to the men with negative splits: who are these runners? Does this split fraction correlate with finishing quickly? We can plot this very easily. We'll use lmplot , which combines regplot with a FacetGrid and automatically fits a linear regression to the data:
In [38]: g = sns.lmplot(x='final_sec', y='split_frac', col='gender', data=data,
               markers=".", scatter_kws=dict(color='c'))
g.map(plt.axhline, y=0.1, color="k", ls=":");
Apparently the people with fast splits are the elite runners who are finishing within ~15,000 seconds, or about 4 hours.
People slower than that are much less likely to have a fast second split.
Further Resources
Matplotlib Resources
A single chapter in a book can never hope to cover all the available features and plot types available in Matplotlib. As with
other packages we've seen, liberal use of IPython's tab-completion and help functions (see Help and Documentation in
IPython) can be very helpful when exploring Matplotlib's API. In addition, Matplotlib’s online documentation can be a
helpful reference. See in particular the Matplotlib gallery linked on that page: it shows thumbnails of hundreds of different
plot types, each one linked to a page with the Python code snippet used to generate it. In this way, you can visually
inspect and learn about a wide range of different plotting styles and visualization techniques.
For a book-length treatment of Matplotlib, I would recommend Interactive Applications Using Matplotlib, written by
Matplotlib core developer Ben Root.
Other Python Graphics Libraries
Although Matplotlib is the most prominent Python visualization library, there are other more modern tools that are worth
exploring as well. I'll mention a few of them briefly here:
Bokeh is a JavaScript visualization library with a Python frontend that creates highly interactive visualizations
capable of handling very large and/or streaming datasets. The Python front-end outputs a JSON data structure that
can be interpreted by the Bokeh JS engine.
Plotly is the eponymous open source product of the Plotly company, and is similar in spirit to Bokeh. Because Plotly
is the main product of a startup, it is receiving a high level of development effort. Use of the library is entirely free.
Vispy is an actively developed project focused on dynamic visualizations of very large datasets. Because it is built to
target OpenGL and make use of efficient graphics processors in your computer, it is able to render some quite large
and stunning visualizations.
Vega and Vega-Lite are declarative graphics representations, and are the product of years of research into the
fundamental language of data visualization. The reference rendering implementation is JavaScript, but the API is
language agnostic. There is a Python API under development in the Altair package. Though as of summer 2016 it's
not yet fully mature, I'm quite excited for the possibilities of this project to provide a common reference point for
visualization in Python and other languages.
The visualization space in the Python community is very dynamic, and I fully expect this list to be out of date as soon as it
is published.
Python社区中数据可视化领域变化非常快,作者估计上述列表在本书出版时可能就已经过时了。
Keep an eye out for what's coming in the future!
希望读者能保持对这个领域未来的关注。
Machine Learning
机器学习
In many ways, machine learning is the primary means by which data science manifests itself to the broader world.
Machine learning is where these computational and algorithmic skills of data science meet the statistical thinking of data
science, and the result is a collection of approaches to inference and data exploration that are not about effective theory
so much as effective computation.
在很多方面,机器学习是数据科学向更广阔世界展现自身的主要途径。在机器学习中,数据科学的计算与算法技能和统计思维相结合,产生了一整套用于推断和数据探索的方法。这些方法关注的与其说是有效的理论,不如说是有效的计算。
The term "machine learning" is sometimes thrown around as if it is some kind of magic pill: apply machine learning to
your data, and all your problems will be solved! As you might expect, the reality is rarely this simple. While these methods
can be incredibly powerful, to be effective they must be approached with a firm grasp of the strengths and weaknesses of
each method, as well as a grasp of general concepts such as bias and variance, overfitting and underfitting, and more.
术语“机器学习”有时被随意使用,仿佛它是某种灵丹妙药:只要在数据上应用机器学习,所有问题都会迎刃而解!正如你所料,现实很少这么简单。这些方法可能非常强大,但要有效地使用它们,你必须牢固掌握每种方法的优缺点,以及偏差与方差、过拟合与欠拟合等基本概念。
This chapter will dive into practical aspects of machine learning, primarily using Python's Scikit-Learn package. This is not
meant to be a comprehensive introduction to the field of machine learning; that is a large subject and necessitates a more
technical approach than we take here. Nor is it meant to be a comprehensive manual for the use of the Scikit-Learn
package (for this, you can refer to the resources listed in Further Machine Learning Resources). Rather, the goals of this
chapter are:
To introduce the fundamental vocabulary and concepts of machine learning.
To introduce the Scikit-Learn API and show some examples of its use.
To take a deeper dive into the details of several of the most important machine learning approaches, and develop an
intuition into how they work and when and where they are applicable.
本章会深入介绍机器学习的实践内容,主要使用Python的Scikit-Learn包。本章并不打算全面介绍机器学习领域:这一课题过于庞大,需要比本书更偏技术的方式才能阐述清楚。本章也不是Scikit-Learn包的完整使用手册(这方面可以参考更多机器学习资源一节中列出的资料)。本章的目标是:
介绍机器学习的基本术语和概念。
介绍Scikit-Learn的API,并展示一些使用例子。
深入介绍几种最重要的机器学习方法的细节,帮助你建立关于它们如何工作、以及在何时何处适用的直觉。
Much of this material is drawn from the Scikit-Learn tutorials and workshops I have given on several occasions at PyCon,
SciPy, PyData, and other conferences. Any clarity in the following pages is likely due to the many workshop participants
and co-instructors who have given me valuable feedback on this material over the years!
本章的很多材料来自作者多次在PyCon、SciPy、PyData和其他会议上讲授的Scikit-Learn教程和工作坊。接下来的内容如有清晰易懂之处,很可能要归功于多年来为这些材料提供宝贵反馈的众多工作坊参与者和合作讲师!
Finally, if you are seeking a more comprehensive or technical treatment of any of these subjects, I've listed several
resources and references in Further Machine Learning Resources.
最后,如果你在寻找关于这些主题更全面或更深入的技术资料,可以参考更多机器学习资源一节中列出的资源和参考文献。
What Is Machine Learning?
什么是机器学习?
Before we take a look at the details of various machine learning methods, let's start by looking at what machine learning
is, and what it isn't. Machine learning is often categorized as a subfield of artificial intelligence, but I find that
categorization can often be misleading at first brush. The study of machine learning certainly arose from research in this
context, but in the data science application of machine learning methods, it's more helpful to think of machine learning as
a means of building models of data.
在深入了解各种机器学习方法的细节之前,我们先来看看机器学习是什么、不是什么。机器学习经常被归为人工智能的一个子领域,但作者发现这种分类乍看之下常常会产生误导。机器学习的研究确实起源于这一背景,但在数据科学对机器学习方法的应用中,把机器学习看作一种构建数据模型的手段会更有帮助。
Fundamentally, machine learning involves building mathematical models to help understand data. "Learning" enters the
fray when we give these models tunable parameters that can be adapted to observed data; in this way the program can
be considered to be "learning" from the data. Once these models have been fit to previously seen data, they can be used
to predict and understand aspects of newly observed data. I'll leave to the reader the more philosophical digression
regarding the extent to which this type of mathematical, model-based "learning" is similar to the "learning" exhibited by
the human brain.
从根本上说,机器学习就是构建数学模型来帮助理解数据。当我们为这些模型提供可调参数,并让这些参数适应观察到的数据时,“学习”就出现了;这样一来,程序就可以被认为是在从数据中“学习”。一旦这些模型拟合了之前见过的数据,就可以用来预测和理解新观察到的数据的各个方面。至于这种基于数学模型的“学习”在多大程度上类似于人脑所展现的“学习”,这个更偏哲学的问题就留给读者去思考了。
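The idea of "learning" as adjusting tunable parameters to observed data can be sketched in a few lines of NumPy. This is a minimal illustration on assumed synthetic data: a noisy straight line with slope 3 and intercept 1, chosen only for this example. The model family is "a straight line", and np.polyfit plays the role of the learning step.

```python
import numpy as np

# Hypothetical observations that roughly follow y = 3x + 1 plus noise.
rng = np.random.RandomState(0)
x_seen = rng.uniform(0, 10, size=30)
y_seen = 3 * x_seen + 1 + rng.normal(scale=0.5, size=30)

# The model family is "a straight line"; its tunable parameters are the
# slope and intercept. np.polyfit adjusts them to the observed data by
# least squares, and that adjustment is the "learning" step.
slope, intercept = np.polyfit(x_seen, y_seen, deg=1)

# Once fit, the model can make predictions for newly observed inputs.
x_new = np.array([2.5, 7.0])
y_pred = slope * x_new + intercept
print(slope, intercept)
print(y_pred)
```

The learned slope and intercept land close to the values used to generate the data. Only the noise keeps them from matching exactly.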
Understanding the problem setting in machine learning is essential to using these tools effectively, and so we will start
with some broad categorizations of the types of approaches we'll discuss here.
理解机器学习中的问题设定是有效使用这些工具的基础,因此我们首先对本章将要讨论的方法类型做一些大致的分类。
Categories of Machine Learning
机器学习分类
At the most fundamental level, machine learning can be categorized into two main types: supervised learning and
unsupervised learning.
在最基础的层次上,机器学习可以分为两大类:有监督学习和无监督学习。
Supervised learning involves somehow modeling the relationship between measured features of data and some label
associated with the data; once this model is determined, it can be used to apply labels to new, unknown data. This is
further subdivided into classification tasks and regression tasks: in classification, the labels are discrete categories, while
in regression, the labels are continuous quantities. We will see examples of both types of supervised learning in the
following section.
有监督学习是对数据的测量特征与数据相关联的标签之间的关系进行建模;一旦确定了这个模型,就可以用它为新的未知数据添加标签。它可以进一步分为分类任务和回归任务:在分类中,标签是离散的类别;而在回归中,标签是连续的量。我们会在下一小节看到这两种有监督学习的例子。
Unsupervised learning involves modeling the features of a dataset without reference to any label, and is often described
as "letting the dataset speak for itself." These models include tasks such as clustering and dimensionality reduction.
Clustering algorithms identify distinct groups of data, while dimensionality reduction algorithms search for more succinct
representations of the data. We will see examples of both types of unsupervised learning in the following section.
无监督学习是在不参考任何标签的情况下对数据集的特征进行建模,常被描述为“让数据集自己说话”。这类模型包括聚类和降维等任务。聚类算法能识别数据中不同的分组,而降维算法则寻找数据更简洁的表示。我们会在下一小节看到这两种无监督学习的例子。
In addition, there are so-called semi-supervised learning methods, which fall somewhere between supervised learning
and unsupervised learning. Semi-supervised learning methods are often useful when only incomplete labels are
available.
除此之外,还有一类被称为半监督学习的方法,介于有监督学习和无监督学习之间。当只有不完整的标签可用时,半监督学习方法通常很有用。
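Semi-supervised learning is easiest to picture with a small sketch. The example below assumes Scikit-Learn is installed and uses its LabelSpreading estimator, which is one of several semi-supervised methods; the text does not single it out. The two blobs and the choice of 5 labeled points per class are arbitrary illustrative assumptions. The example follows Scikit-Learn's convention of marking unlabeled samples with -1.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.semi_supervised import LabelSpreading

# Two well-separated groups of 2D points with known true labels.
X, y_true = make_blobs(n_samples=200, centers=[[-3, -3], [3, 3]],
                       cluster_std=0.7, random_state=0)

# Pretend only 5 points per class are labeled; every other point is
# marked as unlabeled with -1.
rng = np.random.RandomState(0)
y_partial = np.full_like(y_true, -1)
for label in (0, 1):
    idx = rng.choice(np.flatnonzero(y_true == label), size=5, replace=False)
    y_partial[idx] = label

# Label spreading propagates the few known labels through a
# nearest-neighbor graph built from all points, labeled or not.
model = LabelSpreading(kernel='knn', n_neighbors=10)
model.fit(X, y_partial)
inferred = model.transduction_

unlabeled = y_partial == -1
accuracy = np.mean(inferred[unlabeled] == y_true[unlabeled])
print(f"accuracy on the originally unlabeled points: {accuracy:.2f}")
```

The unlabeled points are used too, because they define the graph along which labels spread. This is what separates the approach from training an ordinary classifier on the 10 labeled points alone.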
Qualitative Examples of Machine Learning Applications
机器学习应用的定性例子
To make these ideas more concrete, let's take a look at a few very simple examples of a machine learning task. These
examples are meant to give an intuitive, non-quantitative overview of the types of machine learning tasks we will be
looking at in this chapter. In later sections, we will go into more depth regarding the particular models and how they are
used. For a preview of these more technical aspects, you can find the Python source that generates the following figures
in the Appendix: Figure Code.
为了让这些概念更具体,我们来看几个非常简单的机器学习任务例子。这些例子旨在直观地、非定量地概述本章将要讨论的机器学习任务类型。在后续小节中,我们会深入介绍每一个具体模型及其使用方法。如果想提前了解这些更技术性的内容,生成下面这些图像的Python代码可以在附录:生成图像的代码中找到。
Classification: Predicting discrete labels
分类:预测离散的标签
We will first take a look at a simple classification task, in which you are given a set of labeled points and want to use
these to classify some unlabeled points.
Imagine that we have the data shown in this figure:
我们首先看一个简单的分类任务:你有一组已标记的点,并希望用它们来对一些未标记的点进行分类。假设我们有下图所示的数据:
附录中生成图像的代码
Here we have two-dimensional data: that is, we have two features for each point, represented by the (x,y) positions of the
points on the plane. In addition, we have one of two class labels for each point, here represented by the colors of the
points. From these features and labels, we would like to create a model that will let us decide whether a new point should
be labeled "blue" or "red."
这里我们有二维数据:也就是说,每个点都有两个特征,用点在平面上的(x,y)位置表示。此外,每个点都属于两个类别标签之一,在图中用点的颜色表示。根据这些特征和标签,我们希望建立一个模型,用来判断一个新的点应该被标记为“蓝色”还是“红色”。
There are a number of possible models for such a classification task, but here we will use an extremely simple one. We
will make the assumption that the two groups can be separated by drawing a straight line through the plane between
them, such that points on each side of the line fall in the same group. Here the model is a quantitative version of the
statement "a straight line separates the classes", while the model parameters are the particular numbers describing the
location and orientation of that line for our data. The optimal values for these model parameters are learned from the data
(this is the "learning" in machine learning), which is often called training the model.
对于这样的分类任务,可能的模型有很多,但这里我们使用一个极其简单的模型。我们假设这两组数据点可以用平面上一条穿过它们之间的直线来区分,直线同一侧的点属于同一组。这里的模型是“一条直线将两个类别分开”这一说法的定量版本,而模型参数则是描述这条直线在我们数据中的位置和方向的具体数值。这些模型参数的最优值是从数据中学习得到的(这就是机器学习中的“学习”),这个过程通常被称为训练模型。
The following figure shows a visual representation of what the trained model looks like for this data:
下图展示了在这份数据上训练好的模型的可视化效果:
附录中生成图像的代码
Now that this model has been trained, it can be generalized to new, unlabeled data. In other words, we can take a new
set of data, draw this model line through it, and assign labels to the new points based on this model. This stage is usually
called prediction. See the following figure:
模型训练好之后,就可以泛化到新的未标记数据上。换句话说,我们可以取一组新的数据,在其中画出模型的这条直线,然后根据模型为新的数据点分配标签。这个阶段通常被称为预测。参见下图:
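The train-then-predict cycle for the straight-line classifier can be sketched with the classic perceptron update rule. The perceptron is an assumption made for illustration: the text does not say how the line in the figures was fit, and the appendix code may use a different method. The two Gaussian groups below are synthetic stand-ins for the "blue" (0) and "red" (1) points.

```python
import numpy as np

# Hypothetical training data: two groups that a straight line can separate.
rng = np.random.RandomState(1)
X_blue = rng.normal(loc=[-2, -2], scale=0.5, size=(50, 2))
X_red = rng.normal(loc=[2, 2], scale=0.5, size=(50, 2))
X_train = np.vstack([X_blue, X_red])
y_train = np.array([0] * 50 + [1] * 50)

# Model parameters describe the line w[0]*x + w[1]*y + b = 0.
# Training: the perceptron rule nudges (w, b) whenever a training point
# falls on the wrong side of the current line.
w = np.zeros(2)
b = 0.0
for epoch in range(100):
    errors = 0
    for xi, yi in zip(X_train, y_train):
        pred = int(xi @ w + b > 0)
        if pred != yi:
            step = 1 if yi == 1 else -1
            w += step * xi
            b += step
            errors += 1
    if errors == 0:  # every training point is on the correct side
        break

def predict(X):
    """Prediction: label points by which side of the learned line they fall on."""
    return (X @ w + b > 0).astype(int)

X_new = np.array([[-1.8, -2.3], [2.3, 1.8]])
print(predict(X_new))
```

The learned parameters are just the three numbers w[0], w[1], and b. Prediction reduces to checking the sign of a linear expression, which carries over unchanged to data with many more features.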
附录中生成图像的代码
This is the basic idea of a classification task in machine learning, where "classification" indicates that the data has
discrete class labels. At first glance this may look fairly trivial: it would be relatively easy to simply look at this data and
draw such a discriminatory line to accomplish this classification. A benefit of the machine learning approach, however, is
that it can generalize to much larger datasets in many more dimensions.
这就是机器学习中分类任务的基本思想,这里的“分类”表示数据具有离散的类别标签。乍一看,这个任务似乎微不足道:直接观察数据并画出这样一条分界线似乎相当容易。然而,机器学习方法的优势在于,它可以推广到规模大得多、维度高得多的数据集上。
For example, this is similar to the task of automated spam detection for email; in this case, we might use the following
features and labels:
feature 1, feature 2, etc. → normalized counts of important words or phrases ("Viagra", "Nigerian prince", etc.)
label → "spam" or "not spam"
例如,这与电子邮件的自动垃圾邮件检测任务类似;在这种情况下,我们可能会用到下面的特征和标签:
特征1、特征2等 → 重要单词或短语的归一化计数(“伟哥”、“尼日利亚王子”等)
标签 → “垃圾邮件”或“非垃圾邮件”
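The following sketch shows what "normalized counts of important words" might look like as a features matrix. The vocabulary, the three example emails, and the normalization (dividing by the number of whitespace-separated words) are all hypothetical choices for illustration. Real spam filters use far larger vocabularies and more careful tokenization.

```python
import numpy as np

# Hypothetical vocabulary of "important" words and phrases; a real filter
# would use thousands of terms chosen from the training data.
vocabulary = ["viagra", "nigerian prince", "meeting", "invoice"]

emails = [
    "Cheap viagra now! Viagra for everyone",
    "A nigerian prince needs your help",
    "Agenda for tomorrow's meeting and the invoice",
]

def featurize(text, vocab):
    """Count each vocabulary term (as a substring) and divide by the word count."""
    lowered = text.lower()
    n_words = len(lowered.split())
    counts = np.array([lowered.count(term) for term in vocab], dtype=float)
    return counts / n_words

# Features matrix: one row per email, one column per vocabulary term.
X = np.vstack([featurize(e, vocabulary) for e in emails])
# Labels for the hand-inspected training emails.
y = np.array(["spam", "spam", "not spam"])
print(X.round(3))
```

Each row is an email and each column a term, the same samples-by-features layout that any classifier expects.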
For the training set, these labels might be determined by individual inspection of a small representative sample of emails;
for the remaining emails, the label would be determined using the model. For a suitably trained classification algorithm
with enough well-constructed features (typically thousands or millions of words or phrases), this type of approach can be
very effective. We will see an example of such text-based classification in In Depth: Naive Bayes Classification.
对于训练集,这些标签可以通过人工检查一小部分有代表性的电子邮件样本来确定;其余电子邮件的标签则由模型来确定。如果分类算法经过适当训练,并拥有足够多精心构造的特征(通常是成千上万乃至上百万个单词或短语),这种方法会非常有效。我们会在深入:朴素贝叶斯分类一节中看到一个这样的文本分类例子。
Some important classification algorithms that we will discuss in more detail are Gaussian naive Bayes (see In Depth:
Naive Bayes Classification), support vector machines (see In-Depth: Support Vector Machines), and random forest
classification (see In-Depth: Decision Trees and Random Forests).
我们后续会详细讨论的一些重要分类算法包括高斯朴素贝叶斯(参见深入:朴素贝叶斯分类)、支持向量机(参见深入:支持向量机)和随机森林分类(参见深入:决策树和随机森林)。
Regression: Predicting continuous labels
回归:预测连续标签
In contrast with the discrete labels of a classification algorithm, we will next look at a simple regression task in which the
labels are continuous quantities.
与分类算法的离散标签不同,接下来我们看一个简单的回归任务,其中的标签是连续的数值。
Consider the data shown in the following figure, which consists of a set of points each with a continuous label:
考虑下图所示的数据,它由一组点组成,每个点都有一个连续的标签:
附录中生成图像的代码
As with the classification example, we have two-dimensional data: that is, there are two features describing each data
point. The color of each point represents the continuous label for that point.
与分类例子一样,这里也是二维数据:即每个数据点由两个特征描述。每个点的颜色代表该点的连续标签。
There are a number of possible regression models we might use for this type of data, but here we will use a simple linear
regression to predict the points. This simple linear regression model assumes that if we treat the label as a third spatial
dimension, we can fit a plane to the data. This is a higher-level generalization of the well-known problem of fitting a line to
data with two coordinates.
对于这类数据,可以使用的回归模型有很多,但这里我们使用简单的线性回归来预测这些点。这个简单的线性回归模型假设:如果把标签看作第三个空间维度,就可以用一个平面来拟合这些数据。这是对大家熟悉的“用一条直线拟合二维坐标数据”这一问题的更高维推广。
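Fitting a plane to data with two features is ordinary least squares with an extra column of ones for the intercept. The sketch below assumes synthetic data generated from a known plane, so the recovered coefficients can be compared with the true ones. It uses NumPy's np.linalg.lstsq rather than any particular library's regression class.

```python
import numpy as np

# Hypothetical data: two features per point and a continuous label that
# roughly follows label = 1.5*f1 - 0.5*f2 + 2 plus noise.
rng = np.random.RandomState(0)
features = rng.uniform(-5, 5, size=(100, 2))
labels = (1.5 * features[:, 0] - 0.5 * features[:, 1] + 2
          + rng.normal(scale=0.3, size=100))

# Treat the label as a third dimension and fit the plane
# label = a*f1 + b*f2 + c by least squares; the column of ones
# lets the solver learn the intercept c.
design = np.column_stack([features, np.ones(len(features))])
(a, b, c), *_ = np.linalg.lstsq(design, labels, rcond=None)

# The fitted plane predicts labels for new, unlabeled points.
new_points = np.array([[0.0, 0.0], [2.0, -1.0]])
predicted = a * new_points[:, 0] + b * new_points[:, 1] + c
print(a, b, c)
print(predicted)
```

The same code works unchanged with more feature columns. In that case the fitted "plane" becomes a hyperplane that can no longer be drawn, which is the high-dimensional case the text goes on to describe.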
We can visualize this setup as shown in the following figure:
我们可以用下图来可视化这个设定:
附录中生成图像的代码
Notice that the feature 1-feature 2 plane here is the same as in the two-dimensional plot from before; in this case,
however, we have represented the labels by both color and three-dimensional axis position. From this view, it seems
reasonable that fitting a plane through this three-dimensional data would allow us to predict the expected label for any set
of input parameters. Returning to the two-dimensional projection, when we fit such a plane we get the result shown in the
following figure:
注意,这里的特征1-特征2平面与前面二维图中的相同;不过在这个例子中,我们同时用颜色和三维坐标轴上的位置来表示标签。从这个视角来看,用一个平面拟合这些三维数据,应该就能为任意一组输入参数预测对应的标签。回到二维投影,拟合这样一个平面后得到的结果如下图所示:
附录中生成图像的代码
This plane of fit gives us what we need to predict labels for new points. Visually, we find the results shown in the following
figure:
拟合得到的平面为我们提供了预测新数据点标签所需的一切。下图直观展示了预测结果:
附录中生成图像的代码
As with the classification example, this may seem rather trivial in a low number of dimensions. But the power of these
methods is that they can be straightforwardly applied and evaluated in the case of data with many, many features.
与分类例子一样,这在维度较少时可能显得相当平凡。但这些方法的威力在于,它们可以直接应用于具有非常多特征的数据,并在这类数据上进行评估。
For example, this is similar to the task of computing the distance to galaxies observed through a telescope—in this case,
we might use the following features and labels:
feature 1, feature 2, etc. → brightness of each galaxy at one of several wavelengths or colors
label → distance or redshift of the galaxy
例如,这与计算望远镜观测到的星系距离的任务类似;在这种情况下,我们可能会使用下面的特征和标签:
特征1、特征2等 → 每个星系在若干波长或颜色之一上的亮度
标签 → 星系的距离或红移
The distances for a small number of these galaxies might be determined through an independent set of (typically more
expensive) observations. Distances to remaining galaxies could then be estimated using a suitable regression model,
without the need to employ the more expensive observation across the entire set. In astronomy circles, this is known as
the "photometric redshift" problem.
少数星系的距离可以通过一组独立的(通常更昂贵的)观测来确定。其余星系的距离则可以使用合适的回归模型来估计,而无需在整个数据集上都采用昂贵的观测方法。在天文学界,这被称为“测光红移”问题。
Some important regression algorithms that we will discuss are linear regression (see In Depth: Linear Regression),
support vector machines (see In-Depth: Support Vector Machines), and random forest regression (see In-Depth: Decision
Trees and Random Forests).
我们会讨论的一些重要回归算法包括线性回归(参见深入:线性回归)、支持向量机(参见深入:支持向量机)和随机森林回归(参见深入:决策树和随机森林)。
Clustering: Inferring labels on unlabeled data
聚类:在未标记的数据上推断标签
The classification and regression illustrations we just looked at are examples of supervised learning algorithms, in which
we are trying to build a model that will predict labels for new data. Unsupervised learning involves models that describe
data without reference to any known labels.
我们刚才看到的分类和回归示例都属于有监督学习算法,其中我们试图构建一个能够为新数据预测标签的模型。无监督学习涉及的模型则是在不参考任何已知标签的情况下描述数据。
One common case of unsupervised learning is "clustering," in which data is automatically assigned to some number of
discrete groups. For example, we might have some two-dimensional data like that shown in the following figure:
无监督学习的一个常见场景是“聚类”,即数据被自动分配到若干个离散的分组中。例如,我们可能有下图所示的二维数据:
附录中生成图像的代码
By eye, it is clear that each of these points is part of a distinct group. Given this input, a clustering model will use the
intrinsic structure of the data to determine which points are related. Using the very fast and intuitive k-means algorithm
(see In Depth: K-Means Clustering), we find the clusters shown in the following figure:
凭肉眼就可以清楚地看出,这些点分别属于不同的分组。对于这样的输入,聚类模型会利用数据的内在结构来确定哪些点是相关联的。使用非常快速且直观的k均值算法(参见深入:k均值聚类),我们会得到下图所示的聚类:
附录中生成图像的代码
k-means fits a model consisting of k cluster centers; the optimal centers are assumed to be those that minimize the sum of squared distances between each point and its assigned center. Again, this might seem like a trivial exercise in two dimensions, but as our
data becomes larger and more complex, such clustering algorithms can be employed to extract useful information from
the dataset.
k均值会拟合出一个由k个聚类中心组成的模型;最优的中心被认为是使每个点到其所属中心的平方距离之和最小的那些中心。同样,在二维情况下这看起来有点平淡无奇,但当数据变得更大、更复杂时,这类聚类算法可以用来从数据集中提取出有用的信息。
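The k-means idea fits in a short NumPy function: alternate between assigning points to their nearest center and moving each center to the mean of its points. This is a minimal sketch of Lloyd's algorithm, not Scikit-Learn's implementation. The deterministic farthest-point initialization and the three synthetic blobs are assumptions made to keep the example reproducible.

```python
import numpy as np

def kmeans(X, k, n_iter=100):
    """Minimal k-means (Lloyd's algorithm) sketch.

    Initialization takes the first point, then repeatedly the point farthest
    from all centers chosen so far (a deterministic stand-in for the random
    or k-means++ initialization that real libraries use).
    """
    centers = [X[0]]
    for _ in range(1, k):
        d = np.min([np.sum((X - c) ** 2, axis=1) for c in centers], axis=0)
        centers.append(X[np.argmax(d)])
    centers = np.array(centers)

    for _ in range(n_iter):
        # Assignment step: each point joins its nearest center.
        dists = np.sum((X[:, None, :] - centers[None, :, :]) ** 2, axis=2)
        labels = np.argmin(dists, axis=1)
        # Update step: move each center to the mean of its points, which
        # minimizes the squared distances within that cluster.
        new_centers = np.array([X[labels == j].mean(axis=0) for j in range(k)])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return centers, labels

rng = np.random.RandomState(0)
true_centers = np.array([[0, 0], [5, 5], [0, 5]])
X = np.vstack([c + rng.normal(scale=0.6, size=(60, 2)) for c in true_centers])
centers, labels = kmeans(X, k=3)
print(np.round(centers, 2))
```

Neither step can increase the total squared distance, so the loop settles into a local optimum. Real implementations usually run several random initializations and keep the best result for exactly this reason.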
We will discuss the k-means algorithm in more depth in In Depth: K-Means Clustering. Other important clustering
algorithms include Gaussian mixture models (See In Depth: Gaussian Mixture Models) and spectral clustering (See
Scikit-Learn's clustering documentation).
我们会在深入:k均值聚类一节中更深入地讨论k均值算法。其他重要的聚类算法包括高斯混合模型(参见深入:高斯混合模型)和谱聚类(参见Scikit-Learn聚类在线文档)。
Dimensionality reduction: Inferring structure of unlabeled data
降维:推断无标记数据的结构
Dimensionality reduction is another example of an unsupervised algorithm, in which labels or other information are
inferred from the structure of the dataset itself. Dimensionality reduction is a bit more abstract than the examples we
looked at before, but generally it seeks to pull out some low-dimensional representation of data that in some way
preserves relevant qualities of the full dataset. Different dimensionality reduction routines measure these relevant
qualities in different ways, as we will see in In-Depth: Manifold Learning.
降维是无监督算法的另一个例子,它从数据集本身的结构中推断标签或其他信息。降维比前面看到的例子更抽象一些,但总的来说,它试图提取数据的某种低维表示,并以某种方式保留完整数据集的相关性质。不同的降维方法以不同的方式衡量这些相关性质,我们会在深入:流形学习中看到这一点。
As an example of this, consider the data shown in the following figure:
以下图展示的数据为例:
附录中生成图像的代码
Visually, it is clear that there is some structure in this data: it is drawn from a one-dimensional line that is arranged in a
spiral within this two-dimensional space. In a sense, you could say that this data is "intrinsically" only one dimensional,
though this one-dimensional data is embedded in higher-dimensional space. A suitable dimensionality reduction model in
this case would be sensitive to this nonlinear embedded structure, and be able to pull out this lower-dimensionality
representation.
从图上可以清楚地看出,这些数据具有某种结构:它们是从一条在二维空间中排列成螺旋状的一维曲线上采样得到的。从某种意义上说,这些数据“本质上”只有一维,只不过嵌入在了更高维的空间中。在这种情况下,一个合适的降维模型应当能够感知这种非线性的嵌入结构,并将这种更低维的表示提取出来。
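Recovering the one-dimensional structure of a spiral can be sketched with the three ingredients behind Isomap:

- a nearest-neighbor graph,
- shortest-path (geodesic) distances along that graph,
- classical multidimensional scaling.

The hypothetical isomap_1d helper below is a bare-bones NumPy illustration, not the Scikit-Learn Isomap class. The spiral parameters and n_neighbors=6 are assumed values, chosen so that neighboring turns of the spiral stay far apart.

```python
import numpy as np

# Hypothetical spiral: points on a 1D curve (parameter t) embedded in 2D.
t = np.linspace(1.5 * np.pi, 4.5 * np.pi, 200)
X = np.column_stack([t * np.cos(t), t * np.sin(t)])

def isomap_1d(X, n_neighbors=6):
    """Bare-bones Isomap sketch: k-nearest-neighbor graph, graph shortest
    paths as geodesic distances, then classical MDS down to one dimension."""
    n = len(X)
    d = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=2))
    graph = np.full((n, n), np.inf)
    nearest = np.argsort(d, axis=1)[:, 1:n_neighbors + 1]  # skip self
    for i in range(n):
        graph[i, nearest[i]] = d[i, nearest[i]]
    graph = np.minimum(graph, graph.T)  # make the graph undirected
    np.fill_diagonal(graph, 0.0)
    # Floyd-Warshall all-pairs shortest paths (fine for a few hundred points).
    for k in range(n):
        graph = np.minimum(graph, graph[:, k, None] + graph[None, k, :])
    # Classical MDS: double-center the squared distances, keep the top eigenvector.
    J = np.eye(n) - np.ones((n, n)) / n
    B = -0.5 * J @ (graph ** 2) @ J
    eigvals, eigvecs = np.linalg.eigh(B)
    return eigvecs[:, -1] * np.sqrt(eigvals[-1])

coord = isomap_1d(X)
# Rank correlation between the recovered coordinate and the true order along t.
ranks = np.argsort(np.argsort(coord))
rho = np.corrcoef(ranks, np.arange(len(t)))[0, 1]
print(f"rank correlation with t: {abs(rho):.3f}")
```

Straight-line distances would jump between turns of the spiral, while graph distances are forced to travel along the curve. That is why the single recovered coordinate orders the points the same way t does.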
The following figure shows a visualization of the results of the Isomap algorithm, a manifold learning algorithm that does
exactly this:
下图展示了Isomap算法的可视化结果,Isomap是一种流形学习算法,恰好能做到这一点:
附录中生成图像的代码
Notice that the colors (which represent the extracted one-dimensional latent variable) change uniformly along the spiral,
which indicates that the algorithm did in fact detect the structure we saw by eye. As with the previous examples, the
power of dimensionality reduction algorithms becomes clearer in higher-dimensional cases. For example, we might wish
to visualize important relationships within a dataset that has 100 or 1,000 features. Visualizing 1,000-dimensional data is
a challenge, and one way we can make this more manageable is to use a dimensionality reduction technique to reduce
the data to two or three dimensions.
注意图中的颜色(代表提取出来的一维隐变量)沿着螺旋线均匀变化,这表明算法确实检测到了我们用肉眼看到的结构。与前面的例子一样,降维算法的威力在更高维的情况下会体现得更加明显。例如,我们可能希望可视化一个具有100或1000个特征的数据集中的重要关系。可视化1000维的数据是一项挑战,而让这件事更易处理的一种方法,就是使用降维技术将数据降到二维或三维。
Some important dimensionality reduction algorithms that we will discuss are principal component analysis (see In Depth:
Principal Component Analysis) and various manifold learning algorithms, including Isomap and locally linear embedding
(See In-Depth: Manifold Learning).
我们会讨论的一些重要降维算法包括主成分分析(参见深入:主成分分析)和各种流形学习算法,如Isomap和局部线性嵌入(参见深入:流形学习)。
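Principal component analysis, the first algorithm in this list, can be sketched with a singular value decomposition of the centered data. The example assumes synthetic 100-feature data that secretly varies along only two directions, so reducing it to two dimensions should keep nearly all of its variance. The pca helper is a hypothetical function for illustration, not Scikit-Learn's PCA class.

```python
import numpy as np

def pca(X, n_components=2):
    """PCA sketch via the SVD of the centered data matrix."""
    X_centered = X - X.mean(axis=0)
    # Rows of Vt are the principal axes, sorted by decreasing variance.
    U, S, Vt = np.linalg.svd(X_centered, full_matrices=False)
    components = Vt[:n_components]
    explained_variance = S[:n_components] ** 2 / (len(X) - 1)
    return X_centered @ components.T, explained_variance

# Hypothetical 100-feature data that secretly varies along only 2 directions.
rng = np.random.RandomState(0)
latent = rng.normal(size=(500, 2)) * [5.0, 2.0]
mixing = rng.normal(size=(2, 100))
X = latent @ mixing + 0.1 * rng.normal(size=(500, 100))

X_2d, variances = pca(X, n_components=2)
print(X_2d.shape)  # 500 samples, now described by 2 features
print(variances)
```

Unlike Isomap, PCA can only find linear structure. It is often the first thing to try because it is fast and has no neighborhood size to tune.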
Summary
总结
Here we have seen a few simple examples of some of the basic types of machine learning approaches. Needless to say,
there are a number of important practical details that we have glossed over, but I hope this section was enough to give
you a basic idea of what types of problems machine learning approaches can solve.
本节中我们看到了几种基本类型的机器学习方法的简单例子。不用说,我们略过了许多重要的实践细节,但希望本节的内容足以让读者对机器学习方法能够解决哪些类型的问题有一个基本的认识。
In short, we saw the following:
Supervised learning: Models that can predict labels based on labeled training data
Classification: Models that predict labels as two or more discrete categories
Regression: Models that predict continuous labels
Unsupervised learning: Models that identify structure in unlabeled data
Clustering: Models that detect and identify distinct groups in the data
Dimensionality reduction: Models that detect and identify lower-dimensional structure in higher-dimensional data
简单来说,我们看到了以下内容:
有监督学习:能够根据带标签的训练数据预测标签的模型
分类:将标签预测为两个或多个离散类别的模型
回归:预测连续标签的模型
无监督学习:能够识别未标记数据内在结构的模型
聚类:检测并识别数据中不同分组的模型
降维:检测并识别高维数据中低维结构的模型
In the following sections we will go into much greater depth within these categories, and see some more interesting
examples of where these concepts can be useful.
在后续小节中,我们会更深入地探讨这些类别,并看到更多这些概念能发挥作用的有趣例子。
All of the figures in the preceding discussion are generated based on actual machine learning computations; the code
behind them can be found in Appendix: Figure Code.
前面讨论中的所有图像都是基于真实的机器学习计算生成的;生成这些图像的代码可以在附录:生成图像的代码中找到。
Introducing Scikit-Learn
Scikit-Learn简介
There are several Python libraries which provide solid implementations of a range of machine learning algorithms. One of
the best known is Scikit-Learn, a package that provides efficient versions of a large number of common algorithms. Scikit-Learn is characterized by a clean, uniform, and streamlined API, as well as by very useful and complete online
documentation. A benefit of this uniformity is that once you understand the basic use and syntax of Scikit-Learn for one
type of model, switching to a new model or algorithm is very straightforward.
Python中有许多软件包提供了一系列机器学习算法的可靠实现。其中最知名的是Scikit-Learn,它提供了大量常用算法的高效实现。Scikit-Learn的特点是拥有一套干净、统一且精简的API,以及非常实用和完整的在线文档。这种统一性的好处在于,一旦你理解了Scikit-Learn中某一类模型的基本用法和语法,切换到新的模型或算法就非常简单。
This section provides an overview of the Scikit-Learn API; a solid understanding of these API elements will form the
foundation for understanding the deeper practical discussion of machine learning algorithms and approaches in the
following chapters.
本节对Scikit-Learn的API进行总体介绍;扎实理解这些API要素,是理解后续章节中对机器学习算法和方法更深入的实践讨论的基础。
We will start by covering data representation in Scikit-Learn, followed by covering the Estimator API, and finally go
through a more interesting example of using these tools for exploring a set of images of hand-written digits.
我们首先介绍Scikit-Learn中的数据表示,然后介绍评估器(Estimator)API,最后通过一个更有趣的例子,使用这些工具来探索一组手写数字图像。
Data Representation in Scikit-Learn
Scikit-Learn的数据表示
Machine learning is about creating models from data: for that reason, we'll start by discussing how data can be
represented in order to be understood by the computer. The best way to think about data within Scikit-Learn is in terms of
tables of data.
机器学习是从数据中创建模型:因此,我们首先讨论应该如何表示数据,才能让计算机理解它。在Scikit-Learn中,理解数据最好的方式是把它看作数据表。
Data as table
数据表
A basic table is a two-dimensional grid of data, in which the rows represent individual elements of the dataset, and the
columns represent quantities related to each of these elements. For example, consider the Iris dataset, famously
analyzed by Ronald Fisher in 1936. We can download this dataset in the form of a Pandas DataFrame using the
seaborn library:
基础的表是一个二维的数据网格,其中每一行代表数据集中的一个独立元素,每一列代表与这些元素相关的某个量。例如鸢尾花(Iris)数据集,它因Ronald Fisher在1936年的分析而闻名。我们可以使用seaborn库将这个数据集下载为Pandas的 DataFrame :
In [1]: import seaborn as sns
iris = sns.load_dataset('iris')
iris.head()
Out[1]:
   sepal_length  sepal_width  petal_length  petal_width species
0           5.1          3.5           1.4          0.2  setosa
1           4.9          3.0           1.4          0.2  setosa
2           4.7          3.2           1.3          0.2  setosa
3           4.6          3.1           1.5          0.2  setosa
4           5.0          3.6           1.4          0.2  setosa
Here each row of the data refers to a single observed flower, and the number of rows is the total number of flowers in the
dataset. In general, we will refer to the rows of the matrix as samples, and the number of rows as n_samples .
数据中的每一行都代表一朵观察到的花,行数就是数据集中花的总数。一般来说,我们将矩阵的行称为样本(sample),将行数记为 n_samples 。
Likewise, each column of the data refers to a particular quantitative piece of information that describes each sample. In
general, we will refer to the columns of the matrix as features, and the number of columns as n_features .
同样,数据中的每一列都代表描述每个样本的某一项定量信息。一般来说,我们将矩阵的列称为特征(feature),将列数记为 n_features 。
Features matrix
特征矩阵
This table layout makes clear that the information can be thought of as a two-dimensional numerical array or matrix,
which we will call the features matrix. By convention, this features matrix is often stored in a variable named X . The
features matrix is assumed to be two-dimensional, with shape [n_samples, n_features] , and is most often
contained in a NumPy array or a Pandas DataFrame , though some Scikit-Learn models also accept SciPy sparse
matrices.
这种表格布局清楚地表明,这些信息可以被看作一个二维的数值数组或矩阵,我们称之为特征矩阵。按照惯例,特征矩阵通常保存在名为 X 的变量中。特征矩阵被假定为二维的,形状为 [n_samples, n_features] ,最常见的是存储在NumPy数组或Pandas的 DataFrame 中,不过一些Scikit-Learn模型也接受SciPy稀疏矩阵。
The samples (i.e., rows) always refer to the individual objects described by the dataset. For example, the sample might
be a flower, a person, a document, an image, a sound file, a video, an astronomical object, or anything else you can
describe with a set of quantitative measurements.
样本(也就是行)总是指数据集所描述的单个对象。例如,样本可以是一朵花、一个人、一份文档、一幅图像、一个声音文件、一段视频、一个天体,或者任何你可以用一组定量测量来描述的东西。
The features (i.e., columns) always refer to the distinct observations that describe each sample in a quantitative manner.
Features are generally real-valued, but may be Boolean or discrete-valued in some cases.
特征(也就是列)总是指以定量方式描述每个样本的不同观测值。特征通常是实数值,但在某些情况下也可能是布尔值或离散值。
Target array
目标数组
In addition to the feature matrix X , we also generally work with a label or target array, which by convention we will
usually call y . The target array is usually one dimensional, with length n_samples , and is generally contained in a
NumPy array or Pandas Series . The target array may have continuous numerical values, or discrete classes/labels.
While some Scikit-Learn estimators do handle multiple target values in the form of a two-dimensional, [n_samples,
n_targets] target array, we will primarily be working with the common case of a one-dimensional target array.
除了特征矩阵 X 之外,我们通常还会用到标签或目标数组,按照惯例我们称之为 y 。目标数组通常是一维的,长度为 n_samples ,一般保存在NumPy数组或Pandas的 Series 中。目标数组的值可以是连续的数值,也可以是离散的类别或标签。虽然一些Scikit-Learn评估器确实能够处理形状为 [n_samples, n_targets] 的二维多目标数组,但我们主要关注一维目标数组这种常见情况。
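The [n_samples, n_features] layout for X and the length-n_samples layout for y can be checked directly with NumPy. In this sketch the first two rows repeat the Iris values shown above. The third, virginica-like row is included only for illustration.

```python
import numpy as np

# Features matrix X, shape [n_samples, n_features]: each row is one flower,
# each column one measurement (sepal length/width, petal length/width).
X = np.array([
    [5.1, 3.5, 1.4, 0.2],
    [4.9, 3.0, 1.4, 0.2],
    [6.3, 3.3, 6.0, 2.5],
])
# Target array y, length n_samples: one label per row of X.
y = np.array(["setosa", "setosa", "virginica"])

n_samples, n_features = X.shape
assert y.shape == (n_samples,)  # one target value per sample
print(n_samples, n_features)  # 3 4
```
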
Often one point of confusion is how the target array differs from the other features columns. The distinguishing feature of
the target array is that it is usually the quantity we want to predict from the data: in statistical terms, it is the dependent
variable. For example, in the preceding data we may wish to construct a model that can predict the species of flower
based on the other measurements; in this case, the species column would be considered the target array.
一个常见的困惑是目标数组与其他特征列有何区别。目标数组的显著特点在于,它通常是我们希望从数据中预测的量:用统计学术语来说,它就是因变量。例如,对于前面的数据,我们可能希望构建一个模型,根据其他测量值来预测花的种类;在这种情况下, species 列就是目标数组。
With this target array in mind, we can use Seaborn (see Visualization With Seaborn) to conveniently visualize the data:
有了目标数组的概念,我们可以使用Seaborn(参见使用Seaborn进行可视化)方便地可视化这些数据:
译者注:原书代码中使用的size参数已经过时,下面的代码已改为height。
In [2]: %matplotlib inline
import seaborn as sns; sns.set()
sns.pairplot(iris, hue='species', height=1.5);
For use in Scikit-Learn, we will extract the features matrix and target array from the DataFrame , which we can do using
some of the Pandas DataFrame operations discussed in Chapter 3:
为了在Scikit-Learn中使用,我们要从 DataFrame 中提取出特征矩阵和目标数组,这可以借助第三章中介绍过的一些Pandas DataFrame 操作来完成:
In [3]: X_iris = iris.drop('species', axis=1)
X_iris.shape
Out[3]: (150, 4)
In [4]: y_iris = iris['species']
y_iris.shape
Out[4]: (150,)
To summarize, the expected layout of features and target values is visualized in the following diagram:
下图直观地总结了特征矩阵和目标数组的预期布局:
附录中生成图像的代码
With this data properly formatted, we can move on to consider the estimator API of Scikit-Learn:
数据格式已经准备好了,接下来我们可以学习Scikit-Learn的评估器API:
Scikit-Learn's Estimator API
Scikit-Learn评估器API
The Scikit-Learn API is designed with the following guiding principles in mind, as outlined in the Scikit-Learn API paper:
Consistency: All objects share a common interface drawn from a limited set of methods, with consistent
documentation.
Inspection: All specified parameter values are exposed as public attributes.
Limited object hierarchy: Only algorithms are represented by Python classes; datasets are represented in standard
formats (NumPy arrays, Pandas DataFrame s, SciPy sparse matrices) and parameter names use standard Python
strings.
Composition: Many machine learning tasks can be expressed as sequences of more fundamental algorithms, and
Scikit-Learn makes use of this wherever possible.
Sensible defaults: When models require user-specified parameters, the library defines an appropriate default value.
Scikit-Learn API在设计时遵循了下面的指导原则,这些原则在Scikit-Learn API论文中有所阐述:
一致性:所有对象都共享一个公共接口,这个接口来源于一组有限的方法,并有一致的文档。
可检查性:所有指定的参数值都以公共属性的形式公开。
有限的对象层次:只有算法用Python类表示;数据集用标准格式表示(NumPy数组、Pandas DataFrame 、SciPy稀疏矩阵),参数名称使用标准的Python字符串。
组合:许多机器学习任务可以表示为一系列更基础算法的组合,Scikit-Learn在任何可能的地方都利用了这一点。
合理的默认值:当模型需要用户指定参数时,软件包会定义一个合适的默认值。
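The inspection, sensible-defaults, and composition principles are easy to see in a few lines. This sketch assumes a recent Scikit-Learn release. LinearRegression, get_params(), and make_pipeline are part of its public API, and the particular estimators are just convenient examples.

```python
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Sensible defaults: the estimator works without any arguments.
default_model = LinearRegression()

# Inspection: hyperparameters given at construction are public attributes,
# and get_params() returns all of them as a dictionary.
custom_model = LinearRegression(fit_intercept=False)
print(custom_model.fit_intercept)                   # False
print(default_model.get_params()["fit_intercept"])  # True (the default)

# Composition: estimators can be chained into a pipeline, which itself
# exposes fit() and predict() like a single estimator.
pipeline = make_pipeline(StandardScaler(), LinearRegression())
step_names = [name for name, _ in pipeline.steps]
print(step_names)  # ['standardscaler', 'linearregression']
```
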
In practice, these principles make Scikit-Learn very easy to use, once the basic principles are understood. Every machine
learning algorithm in Scikit-Learn is implemented via the Estimator API, which provides a consistent interface for a wide
range of machine learning applications.
在实践中,一旦理解了这些基本原则,Scikit-Learn就非常易于使用。Scikit-Learn中的每个机器学习算法都是通过评估器API实现的,它为广泛的机器学习应用提供了一致的接口。
Basics of the API
API基础
Most commonly, the steps in using the Scikit-Learn estimator API are as follows (we will step through a handful of
detailed examples in the sections that follow).
1. Choose a class of model by importing the appropriate estimator class from Scikit-Learn.
2. Choose model hyperparameters by instantiating this class with desired values.
3. Arrange data into a features matrix and target vector following the discussion above.
4. Fit the model to your data by calling the fit() method of the model instance.
5. Apply the model to new data:
For supervised learning, often we predict labels for unknown data using the predict() method.
For unsupervised learning, we often transform or infer properties of the data using the transform() or
predict() method.
最常见的情况下,使用Scikit-Learn评估器API的步骤如下(我们会在后面逐步讲解几个详细的例子):
1. 从Scikit-Learn中导入合适的评估器类,选择一个模型类别。
2. 用所需的值实例化这个类,选择模型的超参数。
3. 按照上面的讨论,将数据整理为特征矩阵和目标向量。
4. 调用模型实例的 fit() 方法,将模型拟合到数据上。
5. 将模型应用到新数据上:
对于有监督学习,我们通常使用 predict() 方法预测未知数据的标签。
对于无监督学习,我们通常使用 transform() 或 predict() 方法来转换数据或推断数据的属性。
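The five steps can be condensed into one short script. This sketch assumes Scikit-Learn is installed and generates synthetic (x, y) data the same way as the example that follows. PCA stands in for a generic unsupervised estimator to show the transform() variant of step 5.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression  # 1. choose a model class

# Synthetic data, generated the same way as in the example that follows.
rng = np.random.RandomState(42)
x = 10 * rng.rand(50)
y = 2 * x - 1 + rng.randn(50)

model = LinearRegression(fit_intercept=True)  # 2. choose hyperparameters
X = x[:, np.newaxis]                          # 3. features matrix, shape (50, 1)
model.fit(X, y)                               # 4. fit the model to the data
x_new = np.linspace(-1, 11, 5)
y_new = model.predict(x_new[:, np.newaxis])   # 5. predict labels for new data
print(model.coef_, model.intercept_)

# Unsupervised variant of step 5: a fitted PCA transforms data instead.
data_2d = np.column_stack([x, y])
pca = PCA(n_components=1).fit(data_2d)
reduced = pca.transform(data_2d)
print(reduced.shape)  # (50, 1)
```

Note the reshape in step 3. Even with a single feature, Scikit-Learn expects a two-dimensional features matrix rather than a one-dimensional array.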
We will now step through several simple examples of applying supervised and unsupervised learning methods.
下面我们通过几个简单的例子,逐步演示有监督和无监督学习方法的应用。
Supervised learning example: Simple linear regression
有监督学习例子:简单线性回归
As an example of this process, let's consider a simple linear regression—that is, the common case of fitting a line to
(x, y) data. We will use the following simple data for our regression example:
作为这个流程的例子,我们考虑简单的线性回归,也就是最常见的将一条直线拟合到 (x, y) 数据上的情况。我们使用下面的简单数据作为回归的例子:
In [5]: import matplotlib.pyplot as plt
import numpy as np
rng = np.random.RandomState(42)
x = 10 * rng.rand(50)
y = 2 * x - 1 + rng.randn(50)
plt.scatter(x, y);
With this data in place, we can use the recipe outlined earlier. Let's walk through the process:
有了数据之后,我们就可以按照刚才的步骤来实现回归。下⾯我们⼀步⼀步的来操作:
1. Choose a class of model
1. 选择模型类型
In Scikit-Learn, every class of model is represented by a Python class. So, for example, if we would like to compute a
simple linear regression model, we can import the linear regression class:
在Scikit-Learn中,每个模型类型都是⼀个Python类。因此如果我们希望计算简单的线性回归模型,我们可以载⼊线性回归类:
In [6]: from sklearn.linear_model import LinearRegression
Note that other more general linear regression models exist as well; you can read more about them in the
sklearn.linear_model module documentation.
注意还有更多通⽤的线性回归模型;你可以在 sklearn.linear_model 模块的在线⽂档中学到更多的内容。
2. Choose model hyperparameters
2. 选择模型超参数
An important point is that a class of model is not the same as an instance of a model.
要记住的⼀个重要的点是⼀个模型的类别与⼀个模型的实例不是同⼀个东西。
Once we have decided on our model class, there are still some options open to us. Depending on the model class we are
working with, we might need to answer one or more questions like the following:
Would we like to fit for the offset (i.e., y-intercept)?
Would we like the model to be normalized?
Would we like to preprocess our features to add model flexibility?
What degree of regularization would we like to use in our model?
How many model components would we like to use?
我们决定了模型类别之后,还有一些选项可以选择。取决于我们选择的模型类别,我们可能需要回答下面一个或多个问题:
我们需要拟合偏移(例如y截距)吗?
我们需要模型归⼀化吗?
我们需要预处理特征来增加模型的灵活性吗?
我们希望在模型中使用多大程度的正则化?
我们想要使⽤多少个模型的组件?
These are examples of the important choices that must be made once the model class is selected. These choices are
often represented as hyperparameters, or parameters that must be set before the model is fit to data. In Scikit-Learn,
hyperparameters are chosen by passing values at model instantiation. We will explore how you can quantitatively
motivate the choice of hyperparameters in Hyperparameters and Model Validation.
⼀旦模型类别选定后,上⾯列出的都是⼀些重要的选择。这些选择通常被称为超参数,或者解释为在模型拟合数据前被设置的参数。在
Scikit-Learn中,超参数通过向模型实例传递参数值来设置。我们会在超参数和模型验证⼀节中深⼊讨论如何定量调整这些超参数的值。
For our linear regression example, we can instantiate the LinearRegression class and specify that we would like to
fit the intercept using the fit_intercept hyperparameter:
对于我们线性回归例⼦来说,我们可以实例化 LinearRegression 类并且使⽤ fit_intercept 参数来设置你是否希望拟合截距值:
In [7]: model = LinearRegression(fit_intercept=True)
model
Out[7]: LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None, normalize=False)
Keep in mind that when the model is instantiated, the only action is the storing of these hyperparameter values. In
particular, we have not yet applied the model to any data: the Scikit-Learn API makes very clear the distinction between
choice of model and application of model to data.
记住当模型被实例化后,唯⼀的动作就是保存了超参数的值。也就是说我们还未将模型应⽤到任何数据上:Scikit-Learn API将模型选择和
将模型应⽤在数据上区分的很清楚。
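The `fit_intercept` hyperparameter answers the first question in the list above ("Would we like to fit for the offset?"). As a rough illustration (a NumPy sketch with a hypothetical helper, not `LinearRegression`'s actual code), fitting the same data with and without an intercept column shows what the choice changes: when the true line has an offset, forcing the fit through the origin distorts the slope. Like any hyperparameter, the choice is fixed before the data are seen.

```python
import numpy as np

rng = np.random.RandomState(0)
x = 10 * rng.rand(30)
y = 2 * x - 1 + 0.1 * rng.randn(30)   # a line with a nonzero offset plus small noise
X = x[:, np.newaxis]


def least_squares_line(X, y, fit_intercept):
    """Hypothetical helper: ordinary least squares, optionally with an intercept column."""
    A = np.column_stack([X, np.ones(len(X))]) if fit_intercept else X
    params, *_ = np.linalg.lstsq(A, y, rcond=None)
    slope = params[0]
    # Without the column of ones the fitted line is forced through the origin.
    intercept = params[1] if fit_intercept else 0.0
    return slope, intercept


print(least_squares_line(X, y, fit_intercept=True))   # close to (2, -1)
print(least_squares_line(X, y, fit_intercept=False))  # intercept fixed at 0, slope pulled below 2
```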
3. Arrange data into a features matrix and target vector
3. 将数据组合成特征矩阵和目标向量
Previously we detailed the Scikit-Learn data representation, which requires a two-dimensional features matrix and a one-dimensional target array. Here our target variable y is already in the correct form (a length-n_samples array), but we
need to massage the data x to make it a matrix of size [n_samples, n_features] . In this case, this amounts to a
simple reshaping of the one-dimensional array:
前⾯我们详细介绍了Scikit-Learn数据表⽰,它需要⼀个⼆维的特征矩阵和⼀个⼀维的⽬标数组。这⾥我们的⽬标变量 y 已经是正确格式了
(⻓度为 n_samples 的数组),但是我们需要将数据 x 变成⼀个形状为 [n_samples, n_features] 的矩阵。在这个情况下,我们
需要将⼀个⼀维数组进⾏变形:
In [8]: X = x[:, np.newaxis]
X.shape
Out[8]: (50, 1)
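Adding a length-1 axis is all the reshaping requires. The small NumPy sketch below shows that `x[:, np.newaxis]` used above and the common alternatives `x.reshape(-1, 1)` and `x[:, None]` all produce the same `[n_samples, 1]` matrix; which one to use is a matter of taste.

```python
import numpy as np

x = np.arange(5.0)         # one-dimensional array, shape (5,)
X1 = x[:, np.newaxis]      # the form used in the text
X2 = x.reshape(-1, 1)      # -1 asks NumPy to infer n_samples
X3 = x[:, None]            # np.newaxis is simply an alias for None
print(x.shape, X1.shape)   # (5,) (5, 1)
```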
4. Fit the model to your data
4. 将模型拟合到数据上
Now it is time to apply our model to data. This can be done with the fit() method of the model:
现在是时候将我们的模型应⽤在数据上了。这可以通过模型的 fit() ⽅法实现:
In [9]: model.fit(X, y)
Out[9]: LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None, normalize=False)
This fit() command causes a number of model-dependent internal computations to take place, and the results of
these computations are stored in model-specific attributes that the user can explore. In Scikit-Learn, by convention all
model parameters that were learned during the fit() process have trailing underscores; for example in this linear
model, we have the following:
执⾏ fit() ⽅法会导致⼀系列的模型内部计算,计算得到的结果会保存在模型对象的属性上,⽤⼾可以查看它们。在Scikit-Learn中习惯
上所有通过 fit() 过程学习得到的模型参数都有下划线后缀;例如在这个线性模型中,我们有下⾯的属性:
In [10]: model.coef_
Out[10]: array([1.9776566])
In [11]: model.intercept_
Out[11]: -0.9033107255311164
These two parameters represent the slope and intercept of the simple linear fit to the data. Comparing to the data
definition, we see that they are very close to the input slope of 2 and intercept of -1.
这两个参数代表我们拟合数据后得到的斜率和截距。与数据的定义相比较,可以看出它们非常接近输入的斜率2和截距-1。
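Because a simple linear regression is an ordinary least-squares fit, the two learned parameters should agree with the textbook closed-form formulas: slope = cov(x, y) / var(x) and intercept = mean(y) - slope * mean(x). The sketch below regenerates the same data as In [5] and computes both directly in NumPy, with `np.polyfit` as an independent cross-check; it is an illustration of the math, not of how Scikit-Learn computes `coef_` and `intercept_` internally.

```python
import numpy as np

# Same data as In [5].
rng = np.random.RandomState(42)
x = 10 * rng.rand(50)
y = 2 * x - 1 + rng.randn(50)

# Closed-form simple linear regression.
x_mean, y_mean = x.mean(), y.mean()
slope = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
intercept = y_mean - slope * x_mean
print(slope, intercept)

# np.polyfit with degree 1 returns [slope, intercept] (highest power first).
print(np.polyfit(x, y, 1))
```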
One question that frequently comes up regards the uncertainty in such internal model parameters. In general, Scikit-Learn does not provide tools to draw conclusions from internal model parameters themselves: interpreting model
parameters is much more a statistical modeling question than a machine learning question. Machine learning rather
focuses on what the model predicts. If you would like to dive into the meaning of fit parameters within the model, other
tools are available, including the Statsmodels Python package.
一个经常被提到的问题是这些模型内部参数的不确定性。通常来说,Scikit-Learn不提供根据模型内部参数本身得出结论的工具:解释
模型参数更多是一个统计建模问题,而非机器学习问题。机器学习更加关注的是模型的预测结果。如果你希望深入了解模型拟合参数
的含义,可以使用其他工具,包括 Statsmodels Python 包。
5. Predict labels for unknown data
5. 对未知数据进行预测
Once the model is trained, the main task of supervised machine learning is to evaluate it based on what it says about
new data that was not part of the training set. In Scikit-Learn, this can be done using the predict() method. For the
sake of this example, our "new data" will be a grid of x values, and we will ask what y values the model predicts:
⼀旦模型训练好了,有监督机器学习的主要任务就是⽤它来评估不属于训练集的数据结果。在Scikit-Learn中,可以通过 predict() ⽅法
来实现。在这个例⼦中,我们的“新数据”是⼀个x值的⽹格,我们使⽤模型来预测出相应的y值:
In [12]: xfit = np.linspace(-1, 11)
As before, we need to coerce these x values into a [n_samples, n_features] features matrix, after which we can
feed it to the model:
如前所述,我们需要将这个x向量转变成⼀个 [n_samples, n_features] 的特征矩阵,然后才能使⽤模型进⾏预测:
In [13]: Xfit = xfit[:, np.newaxis]
yfit = model.predict(Xfit)
Finally, let's visualize the results by plotting first the raw data, and then this model fit:
最后,让我们在图表中画出原始数据的散点和新数据的直线:
In [14]: plt.scatter(x, y)
plt.plot(xfit, yfit);
Typically the efficacy of the model is evaluated by comparing its results to some known baseline, as we will see in the
next example.
模型的性能可以通过对结果和已知的基线进⾏⽐较来评估,我们会在下⼀个例⼦中看到。
Supervised learning example: Iris classification
有监督学习例⼦:鸢尾花分类
Let's take a look at another example of this process, using the Iris dataset we discussed earlier. Our question will be this:
given a model trained on a portion of the Iris data, how well can we predict the remaining labels?
让我们再通过⼀个例⼦来介绍这个过程,本例中我们使⽤前⾯的鸢尾花数据集。我们的问题是:给定鸢尾花数据集的⼀部分⽤来训练模
型,我们能多好的预测剩余数据的标签?
For this task, we will use an extremely simple generative model known as Gaussian naive Bayes, which proceeds by
assuming each class is drawn from an axis-aligned Gaussian distribution (see In Depth: Naive Bayes Classification for
more details). Because it is so fast and has no hyperparameters to choose, Gaussian naive Bayes is often a good model
to use as a baseline classification, before exploring whether improvements can be found through more sophisticated
models.
对于这个任务,我们会使用一个极其简单的生成模型,称为高斯朴素贝叶斯模型,它假设每个类别的数据都来自一个轴对齐的
高斯分布(参见深入:朴素贝叶斯分类)。这个模型速度极快,而且没有需要选择的超参数,因此高斯朴素贝叶斯经常被用作一
个基准分类模型,在尝试用更复杂的模型改进性能之前优先使用。
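The assumption "each class is drawn from an axis-aligned Gaussian distribution" translates into very little code. The following from-scratch sketch (`TinyGaussianNB` is a hypothetical name; Scikit-Learn's `GaussianNB` adds refinements such as variance smoothing scaled to the data) estimates one mean and one variance per feature for every class and predicts the class with the largest log-posterior. Apart from a tiny numerical constant it has nothing to tune, which is what makes this kind of model a convenient baseline.

```python
import numpy as np


class TinyGaussianNB:
    """Minimal Gaussian naive Bayes sketch (hypothetical; not sklearn's GaussianNB).

    Each class is modelled as an axis-aligned Gaussian: one mean and one variance
    per feature, with the features treated as independent given the class.
    """

    def __init__(self, eps=1e-9):
        self.eps = eps  # small constant added to variances to avoid division by zero

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)
        self.classes_ = np.unique(y)
        self.theta_ = np.array([X[y == c].mean(axis=0) for c in self.classes_])
        self.var_ = np.array([X[y == c].var(axis=0) for c in self.classes_]) + self.eps
        self.log_prior_ = np.log([np.mean(y == c) for c in self.classes_])
        return self

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        # diff has shape (n_samples, n_classes, n_features).
        diff = X[:, np.newaxis, :] - self.theta_[np.newaxis, :, :]
        # Sum of per-feature Gaussian log-densities = log-likelihood under each class.
        log_lik = -0.5 * np.sum(np.log(2 * np.pi * self.var_) + diff ** 2 / self.var_, axis=2)
        # Choose the class with the largest log-likelihood plus log-prior.
        return self.classes_[np.argmax(log_lik + self.log_prior_, axis=1)]


rng = np.random.RandomState(0)
X = np.vstack([rng.randn(50, 2), rng.randn(50, 2) + 6])   # two well-separated blobs
y = np.array([0] * 50 + [1] * 50)
model = TinyGaussianNB().fit(X, y)
print(model.predict([[0.5, -0.2], [5.5, 6.3]]))
```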
We would like to evaluate the model on data it has not seen before, and so we will split the data into a training set and a
testing set. This could be done by hand, but it is more convenient to use the train_test_split utility function:
我们希望通过模型没有训练到的数据对它的性能进⾏评估,因此我们需要将数据分为训练集和测试集。这可以通过⼿⼯完成,还可以使⽤
train_test_split ⼯具函数很⽅便的实现:
Translator's note: in the code below, the deprecated cross_validation module has been replaced with model_selection.
In [15]: from sklearn.model_selection import train_test_split
Xtrain, Xtest, ytrain, ytest = train_test_split(X_iris, y_iris,
random_state=1)
With the data arranged, we can follow our recipe to predict the labels:
数据准备好后,我们可以依照步骤对测试集数据的标签进⾏预测:
In [16]: from sklearn.naive_bayes import GaussianNB  # 1. choose model class
model = GaussianNB()                                  # 2. instantiate model
model.fit(Xtrain, ytrain)                             # 3. fit model to data
y_model = model.predict(Xtest)                        # 4. predict on new data
Finally, we can use the accuracy_score utility to see the fraction of predicted labels that match their true value:
最后,我们可以通过 accuracy_score ⼯具来查看有多少⽐例的标签我们是预测正确的:
In [17]: from sklearn.metrics import accuracy_score
accuracy_score(ytest, y_model)
Out[17]: 0.9736842105263158
With an accuracy topping 97%, we see that even this very naive classification algorithm is effective for this particular
dataset!
准确率⾼达97%,可以看到对于这个数据集来说即使如此简单的分类算法也可以⾮常有效。
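Accuracy is simply the fraction of matching labels, so a NumPy one-liner reproduces what `accuracy_score` reports by default. The second part is a small arithmetic check, assuming `train_test_split` used its default 25% test fraction here: 150 × 0.25 = 37.5 is rounded up to 38 test samples, and 37/38 equals the score in Out[17], which suggests a single flower was misclassified.

```python
import numpy as np


def accuracy(y_true, y_pred):
    """Fraction of predicted labels equal to the true labels."""
    return np.mean(np.asarray(y_true) == np.asarray(y_pred))


print(accuracy([0, 1, 2, 2], [0, 1, 1, 2]))  # 3 of 4 correct -> 0.75

# Consistent with Out[17]: 37 correct predictions out of 38 test samples.
print(37 / 38)
```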
Unsupervised learning example: Iris dimensionality
⽆监督学习例⼦:鸢尾花数据集降维
As an example of an unsupervised learning problem, let's take a look at reducing the dimensionality of the Iris data so as
to more easily visualize it. Recall that the Iris data is four dimensional: there are four features recorded for each sample.
作为⽆监督学习问题的例⼦,我们来看⼀下对鸢尾花数据集进⾏降维处理令它们更容易可视化。我们都已经知道鸢尾花数据集有四个维
度:也就是每个样本都记录了四个特征的数据。
The task of dimensionality reduction is to ask whether there is a suitable lower-dimensional representation that retains
the essential features of the data. Often dimensionality reduction is used as an aid to visualizing data: after all, it is much
easier to plot data in two dimensions than in four dimensions or higher!
降维的任务是找出是否存在一种合适的低维度数据表示,能够保留数据的关键特征。通常降维被用来辅助数据可视化:毕竟在二维平面
上作图要比在四维甚至更高维度上作图容易得多。
Here we will use principal component analysis (PCA; see In Depth: Principal Component Analysis), which is a fast linear
dimensionality reduction technique. We will ask the model to return two components—that is, a two-dimensional
representation of the data.
这里我们会使用主成分分析(PCA;参见深入:主成分分析),它是一种快速的线性降维方法。我们会要求模型返回两个主成分,即数
据的二维表示。
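PCA itself can be sketched in a few lines of NumPy: centre the data, take the singular value decomposition, and project onto the first rows of `Vt`, which are the directions of largest variance. The sketch below is not Scikit-Learn's implementation (for instance, the signs of the components may differ from what `PCA` returns), and because the Iris DataFrame is not loaded here it uses correlated random 4-D data as a stand-in with the same number of rows and features.

```python
import numpy as np


def pca_fit_transform(X, n_components=2):
    """PCA sketch via the SVD of the centred data (not Scikit-Learn's code)."""
    X = np.asarray(X, dtype=float)
    Xc = X - X.mean(axis=0)                  # PCA works on mean-centred data
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    components = Vt[:n_components]           # principal axes, largest variance first
    explained_variance = S[:n_components] ** 2 / (len(X) - 1)
    return Xc @ components.T, components, explained_variance


rng = np.random.RandomState(1)
X = rng.randn(150, 4) @ rng.randn(4, 4)      # correlated 4-D stand-in data (assumption)
X_2D, components, explained_variance = pca_fit_transform(X, n_components=2)
print(X_2D.shape)                            # (150, 2)
```

The variance of each projected column equals the corresponding explained variance, which is why the first component is the single most informative direction to plot.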
Following the sequence of steps outlined earlier, we have:
依照前⾯介绍的步骤,我们可以:
In [18]: from sklearn.decomposition import PCA  # 1. choose the model class
model = PCA(n_components=2)                     # 2. instantiate the model with hyperparameters
model.fit(X_iris)                               # 3. fit to data (note: no y is passed)
X_2D = model.transform(X_iris)                  # 4. transform the data to two dimensions
Now let's plot the results. A quick way to do this is to insert the results into the original Iris DataFrame , and use
Seaborn's lmplot to show the results:
下⾯绘制结果。最简单的⽅式是将结果作为列插⼊回原始的鸢尾花 DataFrame ,然后使⽤Seaborn的 lmplot 来展⽰结果:
In [19]: iris['PCA1'] = X_2D[:, 0]
iris['PCA2'] = X_2D[:, 1]
sns.lmplot(x="PCA1", y="PCA2", hue='species', data=iris, fit_reg=False);
We see that in the two-dimensional representation, the species are fairly well separated, even though the PCA algorithm
had no knowledge of the species labels! This indicates to us that a relatively straightforward classification will probably be
effective on the dataset, as we saw before.
我们发现在二维数据表示中,不同种类的花分得相当开,尽管PCA算法对种类标签一无所知。这说明对这个数据集进行相对直接的分类
很可能就会有效,就像前面看到的那样。
Unsupervised learning: Iris clustering
⽆监督学习:鸢尾花数据集聚类
Let's next look at applying clustering to the Iris data. A clustering algorithm attempts to find distinct groups of data without
reference to any labels. Here we will use a powerful clustering method called a Gaussian mixture model (GMM),
discussed in more detail in In Depth: Gaussian Mixture Models. A GMM attempts to model the data as a collection of
Gaussian blobs.
下⾯我们来看看将聚类算法应⽤在鸢尾花数据集上的情况。聚类算法试图在没有任何标签的数据集中找出不同的分组。下⾯我们会使⽤⼀
个强⼤的聚类⽅法称为⾼斯混合模型(GMM),我们会在深⼊:⾼斯混合模型中详细介绍它。GMM试图将数据看成是⼀组⾼斯族群。
We can fit the Gaussian mixture model as follows:
我们可以如下拟合⾼斯混合模型:
Translator's note: because GMM is deprecated, the code below uses GaussianMixture instead.
In [20]: from sklearn.mixture import GaussianMixture  # 1. choose the model class
model = GaussianMixture(n_components=3,
                        covariance_type='full')       # 2. instantiate the model with hyperparameters
model.fit(X_iris)                                     # 3. fit to data (note: no y is passed)
y_gmm = model.predict(X_iris)                         # 4. determine cluster labels
As before, we will add the cluster label to the Iris DataFrame and use Seaborn to plot the results:
像之前一样,我们会给鸢尾花 DataFrame 添加聚类列,然后使用Seaborn绘制结果:
In [21]: iris['cluster'] = y_gmm
sns.lmplot(x="PCA1", y="PCA2", data=iris, hue='species',
           col='cluster', fit_reg=False);
By splitting the data by cluster number, we see exactly how well the GMM algorithm has recovered the underlying label:
the setosa species is separated perfectly within cluster 0, while there remains a small amount of mixing between
versicolor and virginica. This means that even without an expert to tell us the species labels of the individual flowers, the
measurements of these flowers are distinct enough that we could automatically identify the presence of these different
groups of species with a simple clustering algorithm! This sort of algorithm might further give experts in the field clues as
to the relationship between the samples they are observing.
使⽤聚类编号将数据分开,我们可以清楚的看到GMM算法运⾏的多么良好:setosa种类被完美地分到了群组0,剩下的versicolor和
virginica有⼀点混在⼀起,但是也⽐较准确。这意味着即使在没有专家告诉我们如何区分不同种类的花的情况下,我们也可以使⽤计算机
⾃动根据聚类算法将它们区分出来。这种算法还可以为专家提供他们观测的样本之间联系的线索。
Application: Exploring Hand-written Digits
应⽤:分析⼿写数字
To demonstrate these principles on a more interesting problem, let's consider one piece of the optical character
recognition problem: the identification of hand-written digits. In the wild, this problem involves both locating and identifying
characters in an image. Here we'll take a shortcut and use Scikit-Learn's set of pre-formatted digits, which is built into the
library.
下面我们要在一个更有趣的问题上展示这些方法。考虑光学字符识别问题中的一部分:手写数字的识别。在实际应用中,这个问题既包括定位
图像中的字符,也包括识别它们。这里我们走一个捷径,使用Scikit-Learn自带的预处理过的数字图像。
Loading and visualizing the digits data
载⼊和展⽰数字图像
We'll use Scikit-Learn's data access interface and take a look at this data:
我们使⽤Scikit-Learn的数据访问接⼝来载⼊这些图像并且查看⼀下数据内容:
In [22]: from sklearn.datasets import load_digits
digits = load_digits()
digits.images.shape
Out[22]: (1797, 8, 8)
The images data is a three-dimensional array: 1,797 samples each consisting of an 8 × 8 grid of pixels. Let's visualize the
first hundred of these:
图像数据是三维数组:1797个样本每个包括8 × 8像素的图。我们可以展⽰头100张:
In [23]: import matplotlib.pyplot as plt
fig, axes = plt.subplots(10, 10, figsize=(8, 8),
subplot_kw={'xticks':[], 'yticks':[]},
gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i, ax in enumerate(axes.flat):
ax.imshow(digits.images[i], cmap='binary', interpolation='nearest')
ax.text(0.05, 0.05, str(digits.target[i]),
transform=ax.transAxes, color='green')
In order to work with this data within Scikit-Learn, we need a two-dimensional, [n_samples, n_features]
representation. We can accomplish this by treating each pixel in the image as a feature: that is, by flattening out the pixel
arrays so that we have a length-64 array of pixel values representing each digit. Additionally, we need the target array,
which gives the previously determined label for each digit. These two quantities are built into the digits dataset under the
data and target attributes, respectively:
为了要在Scikit-Learn中使⽤这个数据集,我们需要⼀个⼆维的 [n_samples, n_features] 数据表⽰。在本例中我们可以将图像中的
每个像素点当成⼀个特征:也就是说,通过将每个图像的像素数组平铺展开成⼀个⻓度为64的⼀维数组。除此之外,我们还需要⽬标数
组,即上图中每张图所标记的数字组成的数组。这两个量已经内建在数据集中,分别是 data 和 target 属性:
In [24]: X = digits.data
X.shape
Out[24]: (1797, 64)
In [25]: y = digits.target
y.shape
Out[25]: (1797,)
We see here that there are 1,797 samples and 64 features.
我们看到⼀共有1797个样本和64个特征。
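The relationship between the (1797, 8, 8) image array and the (1797, 64) features matrix is a row-by-row flattening, assuming (as the shapes above suggest) that `digits.data` is simply `digits.images` with each 8×8 grid unrolled. The NumPy sketch below uses a tiny stand-in array to show the reshape in both directions; the inverse is the same operation used later as `Xtest.reshape(-1, 8, 8)`.

```python
import numpy as np

images = np.arange(3 * 8 * 8).reshape(3, 8, 8)   # stand-in for digits.images: 3 images of 8x8 pixels
X = images.reshape(len(images), -1)             # one row of 64 pixel values per image
restored = X.reshape(-1, 8, 8)                  # the inverse, as used for test_images later on
print(X.shape)                                  # (3, 64)
```

In this row-major layout, pixel (row r, column c) of an image becomes feature number 8 * r + c.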
Unsupervised learning: Dimensionality reduction
⽆监督学习:降维
We'd like to visualize our points within the 64-dimensional parameter space, but it's difficult to effectively visualize points
in such a high-dimensional space. Instead we'll reduce the dimensions to 2, using an unsupervised method. Here, we'll
make use of a manifold learning algorithm called Isomap (see In-Depth: Manifold Learning), and transform the data to
two dimensions:
我们希望能够将我们的点在⼀个64维的参数空间中可视化出来,但是在这么⾼的维度上有效的可视化是⾮常困难的。所以我们转⽽使⽤⽆
监督⽅法将维度减⾄⼆维。这⾥我们使⽤的是流形学习算法Isomap(参⻅深⼊: 流形学习),然后将数据转换成⼆维:
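In [26]: from sklearn.manifold import Isomap  # 1. choose the model class
iso = Isomap(n_components=2)                   # 2. instantiate the model with hyperparameters
iso.fit(digits.data)                           # 3. fit to data (note: no y is passed)
data_projected = iso.transform(digits.data)    # 4. transform the data to two dimensions
data_projected.shape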
Out[26]: (1797, 2)
We see that the projected data is now two-dimensional. Let's plot this data to see if we can learn anything from its
structure:
我们看到映射后的数据现在是⼆维的了。下⾯我们把降维后的数据绘制出来看我们学习的成果:
In [27]: plt.scatter(data_projected[:, 0], data_projected[:, 1], c=digits.target,
            edgecolor='none', alpha=0.5,
            cmap=plt.get_cmap('Spectral', 10))
plt.colorbar(label='digit label', ticks=range(10))
plt.clim(-0.5, 9.5);
This plot gives us some good intuition into how well various numbers are separated in the larger 64-dimensional space.
For example, zeros (in black) and ones (in purple) have very little overlap in parameter space. Intuitively, this makes
sense: a zero is empty in the middle of the image, while a one will generally have ink in the middle. On the other hand,
there seems to be a more or less continuous spectrum between ones and fours: we can understand this by realizing that
some people draw ones with "hats" on them, which cause them to look similar to fours.
上图直观地展示了各个数字在64维的高维空间中分离得如何。例如数字0和1在参数空间中几乎没有重叠。这在直觉上很容易理解:0在图像中
间是空白的,而1在图像中间通常有笔迹。另一方面,数字1和4之间似乎存在一个近乎连续的过渡:当我们意识到有些人写1时会加上“帽子”,
使它看起来很像4,这就容易理解了。
Overall, however, the different groups appear to be fairly well separated in the parameter space: this tells us that even a
very straightforward supervised classification algorithm should perform suitably on this data. Let's give it a try.
⼤体来说,上图说明不同的数字在它们的特征矩阵空间中都能较好的区分开:这表⽰即使是⼀个很直接简单的有监督分类算法应该也能适
合分类这个数据集。让我们试⼀试。
Classification on digits
数字分类
Let's apply a classification algorithm to the digits. As with the Iris data previously, we will split the data into a training and
testing set, and fit a Gaussian naive Bayes model:
下⾯我们在⼿写数字上应⽤分类算法。就像前⾯鸢尾花数据那样,我们将数据集分为训练集和测试集,然后将这些训练数据拟合到⾼斯朴
素⻉叶斯模型中:
In [28]: Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, random_state=0)
In [29]: from sklearn.naive_bayes import GaussianNB
model = GaussianNB()
model.fit(Xtrain, ytrain)
y_model = model.predict(Xtest)
Now that we have predicted labels for the test set, we can gauge the model's accuracy by comparing the true values of the test set to the
predictions:
模型已经给出了测试集的预测结果,我们可以将预测结果和测试集的真实标签进行比较,得到模型的准确率:
In [30]: from sklearn.metrics import accuracy_score
accuracy_score(ytest, y_model)
Out[30]: 0.8333333333333334
With even this extremely simple model, we find more than 80% accuracy for classification of the digits! However, this single number doesn't
tell us where we've gone wrong. One nice way to find out is to use the confusion matrix, which we can compute with Scikit-Learn and plot
with Seaborn:
使用这个非常简单的模型,我们得到了超过80%的数字分类准确率。然而这个数字并不能告诉我们哪里出错了,输出混淆矩阵是一个好
办法,可以使用Scikit-Learn计算它并使用Seaborn绘制图表:
In [31]: from sklearn.metrics import confusion_matrix
mat = confusion_matrix(ytest, y_model)
sns.heatmap(mat, square=True, annot=True, cbar=False)
plt.xlabel('predicted value')
plt.ylabel('true value');
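A confusion matrix is just a table of counts. The from-scratch sketch below (the helper name `confusion` is hypothetical) uses the same orientation as `sklearn.metrics.confusion_matrix`: rows are true labels and columns are predicted labels, which matches the axis labels in the heatmap above. The diagonal holds the correct predictions, so its sum divided by the total recovers the accuracy.

```python
import numpy as np


def confusion(y_true, y_pred, n_classes):
    """Count matrix: rows = true label, columns = predicted label."""
    mat = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        mat[t, p] += 1
    return mat


y_true = [0, 1, 2, 2, 2, 1]
y_pred = [0, 1, 1, 2, 0, 1]
mat = confusion(y_true, y_pred, n_classes=3)
print(mat)
print(np.trace(mat) / mat.sum())   # accuracy = correct predictions / all predictions
```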
This shows us where the mis-labeled points tend to be: for example, a large number of twos here are mis-classified as
either ones or eights. Another way to gain intuition into the characteristics of the model is to plot the inputs again, with
their predicted labels. We'll use green for correct labels, and red for incorrect labels:
上图为我们展⽰了哪些数字更容易被错误标记:例如⽐较多的数字2被错误分类到了数字1或数字8。另⼀种直观展⽰模型准确率的⽅法是
绘制输⼊的数字图像,还有它们预测的标签。我们使⽤绿⾊展⽰预测正确的标签,红⾊展⽰错误的标签:
In [32]: fig, axes = plt.subplots(10, 10, figsize=(8, 8),
subplot_kw={'xticks':[], 'yticks':[]},
gridspec_kw=dict(hspace=0.1, wspace=0.1))
test_images = Xtest.reshape(-1, 8, 8)
for i, ax in enumerate(axes.flat):
ax.imshow(test_images[i], cmap='binary', interpolation='nearest')
ax.text(0.05, 0.05, str(y_model[i]),
transform=ax.transAxes,
color='green' if (ytest[i] == y_model[i]) else 'red')
Examining this subset of the data, we can gain insight regarding where the algorithm might not be performing optimally.
To go beyond our 80% classification rate, we might move to a more sophisticated algorithm such as support vector
machines (see In-Depth: Support Vector Machines), random forests (see In-Depth: Decision Trees and Random Forests)
or another classification approach.
通过检查这个数据子集,我们可以了解算法在哪些情况下表现不尽如人意。要突破80%的分类准确率,我们可以转向更复杂的算法,例如
支持向量机(参见深入:支持向量机)、随机森林(参见深入:决策树和随机森林)或其他分类方法。
Summary
总结
In this section we have covered the essential features of the Scikit-Learn data representation, and the estimator API.
Regardless of the type of estimator, the same import/instantiate/fit/predict pattern holds. Armed with this information
about the estimator API, you can explore the Scikit-Learn documentation and begin trying out various models on your
data.
在本节中我们介绍了Scikit-Learn数据表⽰⽅式和评估器API的基本概念和使⽤⽅法。⽆论使⽤哪种评估器,载⼊/实例化/拟合/预测这些步
骤都是⼀样的。掌握了评估器API这些信息后,你可以⾃⼰阅读Scikit-Learn⽂档以及开始在数据上尝试使⽤不同的模型。
In the next section, we will explore perhaps the most important topic in machine learning: how to select and validate your
model.
在下⼀节中,我们会讨论也许是本章机器学习中最重要的课题:如何选择和验证你的模型。
Hyperparameters and Model Validation
超参数和模型验证
In the previous section, we saw the basic recipe for applying a supervised machine learning model:
1. Choose a class of model
2. Choose model hyperparameters
3. Fit the model to the training data
4. Use the model to predict labels for new data
上⼀节中,我们学习了应⽤有监督机器学习模型的基本配⽅:
1. 选择⼀个模型类别
2. 选择模型超参数
3. 将模型拟合到训练数据上
4. 将模型在新数据上进⾏预测
The first two pieces of this—the choice of model and choice of hyperparameters—are perhaps the most important part of
using these tools and techniques effectively. In order to make an informed choice, we need a way to validate that our
model and our hyperparameters are a good fit to the data. While this may sound simple, there are some pitfalls that you
must avoid to do this effectively.
上⾯的前两步,选择模型类别和超参数,也许是有效使⽤这些⼯具和技术的最关键部分。为了作出⼀个明智的选择,我们需要⼀个⽅式来
验证我们的模型和超参数,看它们是否拟合数据集。虽然这个⽅式听起来很简单,但是⾥⾯有很多坑你需要避开。
Thinking about Model Validation
思考模型验证
In principle, model validation is very simple: after choosing a model and its hyperparameters, we can estimate how
effective it is by applying it to some of the training data and comparing the prediction to the known value.
原则上,模型验证非常简单:选择了模型类别和它的超参数之后,我们可以将模型应用到部分训练数据上,并将预测值与已知值进行比较,
以此估计模型的效果。
The following sections first show a naive approach to model validation and why it fails, before exploring the use of holdout
sets and cross-validation for more robust model evaluation.
下⾯⾸先介绍⼀个原始的模型验证⽅法和为什么它不正确,然后再介绍使⽤预留的⼦集及交叉验证⽅法来获得更健壮的模型评估结果。
Model validation the wrong way
错误的模型验证
Let's demonstrate the naive approach to validation using the Iris data, which we saw in the previous section. We will start
by loading the data:
让我们展⽰使⽤鸢尾花数据集来进⾏模型验证的⼀个原始⽅法,⾸先导⼊数据:
In [1]: from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
Next we choose a model and hyperparameters. Here we'll use a k-neighbors classifier with n_neighbors=1 . This is a
very simple and intuitive model that says "the label of an unknown point is the same as the label of its closest training
point."
下⾯我们选择模型和超参数。这⾥我们会使⽤k近邻分类器,超参数 n_neighbors=1 。这是⼀个⾮常简单和直观的模型,它认为“未知的
点的标签与距离它最近的训练点的标签是⼀样的”。
In [2]: from sklearn.neighbors import KNeighborsClassifier
model = KNeighborsClassifier(n_neighbors=1)
Then we train the model, and use it to predict labels for data we already know:
然后我们训练模型,⽤训练好的模型来预测训练集的标签:
In [3]: model.fit(X, y)
y_model = model.predict(X)
Finally, we compute the fraction of correctly labeled points:
最后,我们计算得到准确率:
In [4]: from sklearn.metrics import accuracy_score
accuracy_score(y, y_model)
Out[4]: 1.0
We see an accuracy score of 1.0, which indicates that 100% of points were correctly labeled by our model! But is this
truly measuring the expected accuracy? Have we really come upon a model that we expect to be correct 100% of the
time?
我们看到准确率是1.0,这表示100%的点都被我们的模型正确标记了。但这真的衡量了模型预期的准确率吗?我们是否真的得到了一
个在任何时候都能100%正确的模型?
As you may have gathered, the answer is no. In fact, this approach contains a fundamental flaw: it trains and evaluates
the model on the same data. Furthermore, the nearest neighbor model is an instance-based estimator that simply stores
the training data, and predicts labels by comparing new data to these stored points: except in contrived cases, it will get
100% accuracy every time!
你可能已经猜到了,答案是否定的。事实上,这个方法有一个根本性的缺陷:它使用同样的数据来训练和评估模型。此外,最近邻模
型是一个基于实例的评估器,它只是简单地保存训练数据,然后将新数据与这些保存的数据点进行比较来预测标签:除了一些刻意构造
的情况外,它每次都会获得100%的准确率。
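The "memorization" argument can be checked directly. The sketch below implements 1-nearest-neighbour prediction from scratch in NumPy (`one_nn_predict` is a hypothetical helper, not `KNeighborsClassifier`) and deliberately gives it random labels, so there is nothing real to learn. Scored on its own training points it is still perfect, because each point's nearest stored neighbour is itself; scored on points it has not stored, it drops to roughly chance level.

```python
import numpy as np


def one_nn_predict(X_train, y_train, X_query):
    """1-nearest-neighbour sketch: return the label of the closest stored training point."""
    # Squared Euclidean distances, shape (n_query, n_train).
    d2 = ((X_query[:, np.newaxis, :] - X_train[np.newaxis, :, :]) ** 2).sum(axis=2)
    return y_train[np.argmin(d2, axis=1)]


rng = np.random.RandomState(0)
X = rng.rand(100, 4)
y = rng.randint(0, 3, size=100)   # labels are pure noise: there is nothing to learn

# Evaluating on the training points: each point's nearest neighbour is itself.
train_acc = np.mean(one_nn_predict(X, y, X) == y)

# Evaluating on points the model has not stored gives roughly chance-level accuracy.
holdout_acc = np.mean(one_nn_predict(X[:50], y[:50], X[50:]) == y[50:])
print(train_acc, holdout_acc)
```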
Model validation the right way: Holdout sets
模型验证的正确⽅式:保留部分数据
So what can be done? A better sense of a model's performance can be found using what's known as a holdout set: that
is, we hold back some subset of the data from the training of the model, and then use this holdout set to check the model
performance. This splitting can be done using the train_test_split utility in Scikit-Learn:
那么应该怎么做?将⼀部分数据集保留出来不参与训练,并使⽤它们对模型的性能进⾏评估才是正确的办法:意思就是我们将数据中的部
分⼦集从训练集中分离出来,然后再将它们预测的结果和预先标记的结果进⾏对⽐得到模型性能。这可以通过Scikit-Learn的
train_test_split ⼯具完成:
In [5]: from sklearn.model_selection import train_test_split
# split the data into two 50% subsets: a training set and a test set
X1, X2, y1, y2 = train_test_split(X, y, random_state=0,
                                  train_size=0.5)
# fit the model on the training set
model.fit(X1, y1)
# predict on the test set and evaluate the result
y2_model = model.predict(X2)
accuracy_score(y2, y2_model)
Out[5]: 0.9066666666666666
We see here a more reasonable result: the nearest-neighbor classifier is about 90% accurate on this hold-out set. The
hold-out set is similar to unknown data, because the model has not "seen" it before.
这样我们就得到了⼀个更加合理的结果:最近邻分类器在这样划分了训练集和测试集后,能得到⼤约90%的准确率。这⾥保留出的⼦数据
集类似未知的数据,因为模型根本没有⻅过它们。
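Conceptually, a holdout split is just a shuffle of the sample indices followed by a cut. The NumPy sketch below (`holdout_split` is a hypothetical helper) produces two disjoint index sets that together cover every sample. It does not reproduce the exact shuffle that `train_test_split` performs for a given `random_state`, so the indices will differ from the ones behind Out[5].

```python
import numpy as np


def holdout_split(n_samples, train_fraction=0.5, seed=0):
    """Shuffle the sample indices, then cut them into disjoint train/holdout parts."""
    rng = np.random.RandomState(seed)
    idx = rng.permutation(n_samples)
    n_train = int(n_samples * train_fraction)
    return idx[:n_train], idx[n_train:]


train_idx, test_idx = holdout_split(150, train_fraction=0.5)
print(len(train_idx), len(test_idx))
```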
Model validation via cross-validation
对模型使⽤交叉验证
One disadvantage of using a holdout set for model validation is that we have lost a portion of our data to the model
training. In the above case, half the dataset does not contribute to the training of the model! This is not optimal, and can
cause problems – especially if the initial set of training data is small.
上⾯的保留⼦数据集来验证模型的⽅式有⼀个缺点,那就是我们其中⼀部分的数据⽆法参与模型训练过程。在上⾯例⼦中,⼀半的数据集
对于训练模型没有任何贡献。这不是最优化的⽅式,⽽且可能导致问题,特别是原始训练数据规模⽐较⼩的情况下。
One way to address this is to use cross-validation; that is, to do a sequence of fits where each subset of the data is used
both as a training set and as a validation set. Visually, it might look something like this:
解决这个缺点的方法是使用交叉验证;也就是进行一系列的拟合,让数据的每个子集既被用作训练集,也被用作验证集。下面描绘了这个过程:
(Figure: two-fold cross-validation; the code that generates it is in the Appendix.)
Here we do two validation trials, alternately using each half of the data as a holdout set. Using the split data from before,
we could implement it like this:
这⾥我们使⽤两次验证过程,每次使⽤不同的⼀半数据作为保留的数据集来验证模型。使⽤上⾯分好的数据,我们使⽤下⾯的代码实现:
In [6]: y2_model = model.fit(X1, y1).predict(X2)
y1_model = model.fit(X2, y2).predict(X1)
accuracy_score(y1, y1_model), accuracy_score(y2, y2_model)
Out[6]: (0.96, 0.9066666666666666)
What comes out are two accuracy scores, which we could combine (by, say, taking the mean) to get a better measure of
the global model performance. This particular form of cross-validation is a two-fold cross-validation—that is, one in which
we have split the data into two sets and used each in turn as a validation set.
上面输出了两个准确率结果,我们可以将它们组合起来(例如取平均值)来更好地衡量模型的全局性能。这种特殊形式的交叉验证被称为二折
交叉验证(two-fold cross-validation),也就是我们将数据分为两个子集,然后依次使用它们作为验证集。
We could expand on this idea to use even more trials, and more folds in the data—for example, here is a visual depiction
of five-fold cross-validation:
我们可以将这个方法扩展到更多次试验和更多的数据折数,例如下图是一个五折交叉验证的示意图:
(Figure: five-fold cross-validation; the code that generates it is in the Appendix.)
Here we split the data into five groups, and use each of them in turn to evaluate the model fit on the other 4/5 of the data.
This would be rather tedious to do by hand, and so we can use Scikit-Learn's cross_val_score convenience routine
to do it succinctly:
这⾥我们将数据分成5组,每次使⽤其中⼀组来评估模型,其余的4/5⽤来训练模型。每次都要⼿动完成这项⼯作是很⽆聊的,因此我们可
以使⽤Scikit-Learn的 cross_val_score ⼯具来直接完成它:
In [7]: from sklearn.model_selection import cross_val_score
cross_val_score(model, X, y, cv=5)
Out[7]: array([0.96666667, 0.96666667, 0.93333333, 0.93333333, 1.        ])
Repeating the validation across different subsets of the data gives us an even better idea of the performance of the
algorithm.
使⽤不同的⼦数据集重复对模型进⾏验证能为我们提供更好的算法性能结果。
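The fold bookkeeping behind k-fold cross-validation is simple enough to sketch by hand. The generator below (`kfold_indices` is a hypothetical helper) cuts the indices into contiguous, unshuffled folds, similar in spirit to Scikit-Learn's `KFold(shuffle=False)`, and yields one train/test pair per fold so that every sample is used for testing exactly once. Note that for classifiers `cross_val_score(model, X, y, cv=5)` uses stratified folds by default, so this plain sketch will not reproduce Out[7] exactly.

```python
import numpy as np


def kfold_indices(n_samples, n_folds=5):
    """Yield (train_idx, test_idx) pairs; each sample lands in exactly one test fold."""
    fold_sizes = np.full(n_folds, n_samples // n_folds)
    fold_sizes[: n_samples % n_folds] += 1   # spread any remainder over the first folds
    indices = np.arange(n_samples)
    start = 0
    for size in fold_sizes:
        test_idx = indices[start:start + size]
        train_idx = np.concatenate([indices[:start], indices[start + size:]])
        yield train_idx, test_idx
        start += size


folds = list(kfold_indices(150, n_folds=5))
print([len(test) for _, test in folds])   # [30, 30, 30, 30, 30]
```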
Scikit-Learn implements a number of cross-validation schemes that are useful in particular situations; these are
implemented via iterators in the model_selection module (older releases used the now-removed cross_validation module). For example, we might wish to go to the extreme case in
which our number of folds is equal to the number of data points: that is, we train on all points but one in each trial. This
type of cross-validation is known as leave-one-out cross-validation, and can be used as follows:
Scikit-Learn实现了许多有用的交叉验证方案,它们适合于特定的场景;这些方案都是在 model_selection 模块中以迭代器的形式实现的。例如,我们
可能希望采用一种极端的方案,让数据的折数等于数据的样本数:也就是说,每次试验都使用除一个数据点之外的所有数据进行训
练。这种交叉验证被称为留一法(leave-one-out)交叉验证,如下例:
Translator's note: this translation replaces every use of the old cross_validation module with model_selection; the LeaveOneOut instantiation below is likewise updated to the newer, argument-free form.
In [8]: from sklearn.model_selection import LeaveOneOut
scores = cross_val_score(model, X, y, cv=LeaveOneOut())
scores
Out[8]: array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
Because we have 150 samples, leave-one-out cross-validation yields scores for 150 trials, and each score indicates
either a successful (1.0) or unsuccessful (0.0) prediction. Taking the mean of these gives an estimate of the accuracy (equivalently, one minus the error rate):
因为我们有150个样本,leave-one-out交叉验证会得到150个验证结果,结果只有两种状态:验证成功(1.0)或验证失败(0.0)。对上⾯
的结果数组求平均值能得到⼀个估计的准确率:
In [9]: scores.mean()
Out[9]: 0.96
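Leave-one-out is k-fold cross-validation with k equal to the number of samples, so each trial scores exactly one held-out point as 1.0 or 0.0. The from-scratch sketch below runs it with a hypothetical 1-nearest-neighbour helper on a tiny made-up dataset. As an arithmetic check on Out[8] and Out[9]: the score array contains six 0. entries, so the mean is 144/150 = 0.96, that is, an error rate of 0.04.

```python
import numpy as np


def one_nn_predict(X_train, y_train, X_query):
    # Hypothetical helper: label of the closest training point.
    d2 = ((X_query[:, np.newaxis, :] - X_train[np.newaxis, :, :]) ** 2).sum(axis=2)
    return y_train[np.argmin(d2, axis=1)]


def leave_one_out_scores(X, y, predict):
    """Train on all points but one, score the held-out point (1.0 or 0.0), repeat for every point."""
    n = len(X)
    scores = np.empty(n)
    for i in range(n):
        mask = np.arange(n) != i
        y_hat = predict(X[mask], y[mask], X[i:i + 1])
        scores[i] = float(y_hat[0] == y[i])
    return scores


X = np.array([[0.0], [0.1], [0.2], [5.0], [5.1], [5.2]])
y = np.array([0, 0, 0, 1, 1, 1])
scores = leave_one_out_scores(X, y, one_nn_predict)
print(scores, scores.mean())

# Out[9] arithmetic: 150 trials with 6 failures.
print(144 / 150)
```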
Other cross-validation schemes can be used similarly. For a description of what is available in Scikit-Learn, use IPython
to explore the sklearn.model_selection submodule, or take a look at Scikit-Learn's online cross-validation
documentation.
其他交叉验证方案的用法也是类似的。想要查阅Scikit-Learn中可用的交叉验证方案,可以使用IPython来浏览 sklearn.model_selection
模块,或者浏览Scikit-Learn的在线交叉验证文档。
Selecting the Best Model
选择最佳模型
Now that we've seen the basics of validation and cross-validation, we will go into a little more depth regarding model
selection and selection of hyperparameters. These issues are some of the most important aspects of the practice of
machine learning, and I find that this information is often glossed over in introductory machine learning tutorials.
我们已经学习了模型验证和交叉验证的基础,现在可以继续深⼊了解模型选择和超参数选择的内容。这些话题是机器学习实践中最重要的
⼀些内容,作者发现这部分信息经常在机器学习⼊⻔教程中被⼀笔带过。
Of core importance is the following question: if our estimator is underperforming, how should we move forward? There
are several possible answers:
Use a more complicated/more flexible model
Use a less complicated/less flexible model
Gather more training samples
Gather more data to add features to each sample
其中的核心问题是:如果我们的评估器表现不佳,我们应该如何改进?可能的答案有以下几种:
使⽤⼀个更加复杂或更加灵活的模型
使⽤⼀个没那么复杂或没那么灵活的模型
收集更多的训练样本
对每个样本收集更多信息,增加特征
The answer to this question is often counter-intuitive. In particular, sometimes using a more complicated model will give
worse results, and adding more training samples may not improve your results! The ability to determine what steps will
improve your model is what separates the successful machine learning practitioners from the unsuccessful.
对这个问题的解答经常是反直觉的。比方说,有时使用一个更复杂的模型反而会得到更差的结果,而增加训练样本也未必能改进你的结果!
决定采⽤哪些⽅法步骤来改进模型的能⼒是成功的机器学习实践者和不成功的实践者之间的主要区别。
The Bias-variance trade-off
偏差⽅差的权衡
Fundamentally, the question of "the best model" is about finding a sweet spot in the tradeoff between bias and variance.
Consider the following figure, which presents two regression fits to the same dataset:
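“最佳模型”问题根本上是关于在偏差和方差之间寻找最佳平衡点。考虑下图,这是对同一个数据集的两个回归拟合:
(Figure: a straight-line fit and a high-order polynomial fit to the same dataset; the code that generates it is in the Appendix.)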
It is clear that neither of these models is a particularly good fit to the data, but they fail in different ways.
很明显这两个模型都不是拟合数据的最佳模型,但是它们失败的地⽅是不⼀样的。
The model on the left attempts to find a straight-line fit through the data. Because the data are intrinsically more
complicated than a straight line, the straight-line model will never be able to describe this dataset well. Such a model is
said to underfit the data: that is, it does not have enough model flexibility to suitably account for all the features in the
data; another way of saying this is that the model has high bias.
左边的模型试图找出⼀条直线来拟合数据。因为这个数据很明显⽐直线要复杂的多,因此直线模型不可能很好的描述这个数据集。这样的
模型我们称为“⽋拟合”:也就是说,它没有提供⾜够的模型灵活性来反映出数据的所有特征;⽤另⼀种说法就是这个模型有着⾼的偏差。
The model on the right attempts to fit a high-order polynomial through the data. Here the model fit has enough flexibility to
nearly perfectly account for the fine features in the data, but even though it very accurately describes the training data, its
precise form seems to be more reflective of the particular noise properties of the data rather than the intrinsic properties
of whatever process generated that data. Such a model is said to overfit the data: that is, it has so much model flexibility
that the model ends up accounting for random errors as well as the underlying data distribution; another way of saying
this is that the model has high variance.
右边的模型试图使用一个高阶多项式来拟合数据。这个模型具有足够的灵活性,几乎完美地描述了数据中的细微特征。虽然它非常精确地描述了
训练数据,但这种精确的形式更多反映的是数据中特定的噪声特征,而不是生成数据的过程的内在特征。这样的模型被称为是“过拟合”的:也
就是说它的模型灵活性过高,以至于在反映数据潜在分布的同时也反映了随机误差;另一种说法就是这个模型有着高的方差。
To look at this in another light, consider what happens if we use these two models to predict the y-value for some new
data. In the following diagrams, the red/lighter points indicate data that is omitted from the training set:
从另一个角度来看,考虑使用这两个模型来预测一些新数据的y值会发生什么。下图中红色(颜色较浅)的点代表从训练集中剔除出去的数据:
附录中⽣成图像的代码
The score here is the $R^2$ score, or coefficient of determination, which measures how well a model performs relative to
a simple mean of the target values. $R^2=1$ indicates a perfect match, $R^2=0$ indicates the model does no better than
simply taking the mean of the data, and negative values mean even worse models. From the scores associated with
these two models, we can make an observation that holds more generally:
For high-bias models, the performance of the model on the validation set is similar to the performance on the training
set.
For high-variance models, the performance of the model on the validation set is far worse than the performance on
the training set.
这里使用的评估标准是$R^2$分值,或者称为决定系数,它衡量的是模型相对于简单取目标值平均值的表现。
$R^2=1$代表完全吻合,$R^2=0$代表模型与简单取数据平均值没有区别,负数值代表模型的表现还不如简单取平均值。从两个模型的这个分值中,我们可以得到更普遍的结论:
对于高偏差的模型来说,模型在验证集上的性能与在训练集上类似。
对于高方差的模型来说,模型在验证集上的性能比在训练集上差了非常多。
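As a worked check of the definition above: with predictions $\hat{y}_i$ and target mean $\bar{y}$, the coefficient of determination is $R^2 = 1 - \sum_i (y_i - \hat{y}_i)^2 / \sum_i (y_i - \bar{y})^2$, so a model that always predicts $\bar{y}$ scores exactly 0. The sketch below computes this by hand on a few made-up numbers (purely illustrative values) and compares the result with scikit-learn's `r2_score`.

```python
import numpy as np
from sklearn.metrics import r2_score

# Illustrative targets and predictions (made-up values)
y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])

def r_squared(y, y_hat):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = np.sum((y - y_hat) ** 2)        # squared error of the model
    ss_tot = np.sum((y - y.mean()) ** 2)     # squared error of always predicting the mean
    return 1.0 - ss_res / ss_tot

manual = r_squared(y_true, y_pred)
library = r2_score(y_true, y_pred)

# Predicting the mean everywhere gives R^2 = 0 by construction
baseline = r_squared(y_true, np.full_like(y_true, y_true.mean()))
print(manual, library, baseline)
```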
If we imagine that we have some ability to tune the model complexity, we would expect the training score and validation
score to behave as illustrated in the following figure:
如果我们能够调节模型的复杂度,我们预期训练分数和验证分数会呈现下图所示的情况:
附录中⽣成图像的代码
The diagram shown here is often called a validation curve, and we see the following essential features:
The training score is everywhere higher than the validation score. This is generally the case: the model will be a
better fit to data it has seen than to data it has not seen.
For very low model complexity (a high-bias model), the training data is under-fit, which means that the model is a
poor predictor both for the training data and for any previously unseen data.
For very high model complexity (a high-variance model), the training data is over-fit, which means that the model
predicts the training data very well, but fails for any previously unseen data.
For some intermediate value, the validation curve has a maximum. This level of complexity indicates a suitable
trade-off between bias and variance.
这幅图像通常被称为验证曲线,我们观察到下⾯这些关键特征:
训练分数在任何地方都比验证分数要高。通常情况都是如此:模型对它见过的数据的拟合会比对它没见过的数据更好。
对于低复杂度模型(⾼偏差模型)来说,训练数据是⽋拟合的,这代表着模型既不能很好的预测训练数据也不能很好的预测未知数
据。
对于非常高复杂度模型(高方差模型)来说,训练数据是过拟合的,这代表着模型能非常好地预测训练数据,但是不能很好地预测未知数据。
对于中间部分来说,验证曲线有⼀个最⼤值。这个点代表着偏差和⽅差的最佳平衡点。
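The claim that the training score only improves as complexity grows can be sketched without scikit-learn. Below, a 1-D variant of the `make_data` generator defined later in this section is fitted with `np.polyfit` on a simple hold-out split. For nested least-squares models the training $R^2$ cannot decrease as the degree grows, while the validation score is free to rise and then fall. The split sizes and degree range are arbitrary choices for illustration.

```python
import numpy as np

def make_data(N, err=1.0, rseed=1):
    # 1-D variant of the make_data generator in In [11]:
    # a saturating curve sampled at random points, plus Gaussian noise
    rng = np.random.RandomState(rseed)
    x = rng.rand(N) ** 2
    y = 10 - 1.0 / (x + 0.1)
    if err > 0:
        y += err * rng.randn(N)
    return x, y

def r2(y, y_hat):
    # Coefficient of determination, as defined for the R^2 score above
    return 1 - np.sum((y - y_hat) ** 2) / np.sum((y - y.mean()) ** 2)

x, y = make_data(60)
x_train, y_train = x[:40], y[:40]    # simple hold-out split (arbitrary sizes)
x_val, y_val = x[40:], y[40:]

degrees = range(0, 7)
train_scores, val_scores = [], []
for degree in degrees:
    coefs = np.polyfit(x_train, y_train, degree)   # least-squares polynomial fit
    train_scores.append(r2(y_train, np.polyval(coefs, x_train)))
    val_scores.append(r2(y_val, np.polyval(coefs, x_val)))

for d, tr, va in zip(degrees, train_scores, val_scores):
    print(f"degree={d}: train R^2={tr:.3f}, validation R^2={va:.3f}")
```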
The means of tuning the model complexity varies from model to model; when we discuss individual models in depth in
later sections, we will see how each model allows for such tuning.
调整模型复杂度的方法因模型而异;当我们在后面章节深入讨论单个模型时,我们会看到每种模型在这方面的调整方法。
Validation curves in Scikit-Learn
Scikit-Learn中的验证曲线
Let's look at an example of using cross-validation to compute the validation curve for a class of models. Here we will use
a polynomial regression model: this is a generalized linear model in which the degree of the polynomial is a tunable
parameter. For example, a degree-1 polynomial fits a straight line to the data; for model parameters $a$ and $b$:
下⾯我们来看⼀个例⼦说明使⽤交叉验证来计算⼀种模型的验证曲线。这⾥我们将使⽤多项式回归模型:这是⼀个⼴义的线性模型,其中
的多项式的阶是可调整的参数。例如,⼀阶的多项式将数据拟合到⼀条直线上;模型参数有$a$和$b$:
$$ y = ax + b $$
A degree-3 polynomial fits a cubic curve to the data; for model parameters $a, b, c, d$:
一个三阶的多项式将数据拟合到一条三次曲线上;模型参数有$a, b, c, d$:
$$ y = ax^3 + bx^2 + cx + d $$
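To see what the polynomial preprocessor used below actually produces, here is a minimal sketch: `PolynomialFeatures` with `degree=3` turns a single input $x$ into the columns $1, x, x^2, x^3$, so a linear model fitted on those columns learns one weight per power, playing the roles of $d, c, b, a$ in the cubic above.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

# One sample with a single feature x = 2
x = np.array([[2.0]])

# degree=3 expands x into [1, x, x^2, x^3]; the leading 1 is the bias column
cubic = PolynomialFeatures(degree=3).fit_transform(x)
print(cubic)   # [[1. 2. 4. 8.]]
```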
We can generalize this to any number of polynomial features. In Scikit-Learn, we can implement this with a simple linear
regression combined with the polynomial preprocessor. We will use a pipeline to string these operations together (we will
discuss polynomial features and pipelines more fully in Feature Engineering):
我们可以推⼴到任意阶的多项式中。在Scikit-Learn中我们可以通过将线性回归与多项式预处理器结合起来实现这个任务。我们会使⽤管道
将这些操作串联起来(我们会在特征⼯程中更详细的讨论多项式特征和管道):
In [10]: from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
def PolynomialRegression(degree=2, **kwargs):
    return make_pipeline(PolynomialFeatures(degree),
                         LinearRegression(**kwargs))
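The `polynomialfeatures__degree` strings used later in this section come from how `make_pipeline` names its steps: each step is named after its class in lowercase, and a parameter of a step is addressed as `<step name>__<parameter>`. A short check, repeating the `PolynomialRegression` helper so the block runs on its own:

```python
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline

def PolynomialRegression(degree=2, **kwargs):
    return make_pipeline(PolynomialFeatures(degree),
                         LinearRegression(**kwargs))

model = PolynomialRegression(3)

# make_pipeline names each step after its class, lowercased
print(list(model.named_steps))   # ['polynomialfeatures', 'linearregression']

# "<step>__<param>" addresses a parameter of one step
params = model.get_params()
print(params['polynomialfeatures__degree'])   # 3

# The same naming scheme works for changing a parameter in place
model.set_params(polynomialfeatures__degree=5)
```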
Now let's create some data to which we will fit our model:
现在让我们创建一些数据来拟合模型:
In [11]: import numpy as np
def make_data(N, err=1.0, rseed=1):
    # randomly sample the data
    rng = np.random.RandomState(rseed)
    X = rng.rand(N, 1) ** 2
    y = 10 - 1. / (X.ravel() + 0.1)
    if err > 0:
        y += err * rng.randn(N)
    return X, y
X, y = make_data(40)
We can now visualize our data, along with polynomial fits of several degrees:
然后对数据进⾏可视化,包含着不同阶的多项式匹配结果:
In [12]: %matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set()  # use Seaborn plot styling
X_test = np.linspace(-0.1, 1.1, 500)[:, None]
plt.scatter(X.ravel(), y, color='black')
axis = plt.axis()
for degree in [1, 3, 5]:
    y_test = PolynomialRegression(degree).fit(X, y).predict(X_test)
    plt.plot(X_test.ravel(), y_test, label='degree={0}'.format(degree))
plt.xlim(-0.1, 1.0)
plt.ylim(-2, 12)
plt.legend(loc='best');
The knob controlling model complexity in this case is the degree of the polynomial, which can be any non-negative
integer. A useful question to answer is this: what degree of polynomial provides a suitable trade-off between bias (underfitting) and variance (over-fitting)?
这个例子中控制模型复杂度的开关就是多项式的阶数,可以是任何非负整数。这里关键的问题是:哪个阶的多项式在偏差(欠拟合)和
方差(过拟合)之间达到了合适的平衡?
We can make progress in this by visualizing the validation curve for this particular data and model; this can be done
straightforwardly using the validation_curve convenience routine provided by Scikit-Learn. Given a model, data,
parameter name, and a range to explore, this function will automatically compute both the training score and validation
score across the range:
我们还可以进⼀步将这个特殊的数据和模型的验证曲线绘制出来;这可以直接通过Scikit-Learn提供的 validation_curve ⼯具完成。给
定模型、数据、参数名称和⼀个范围,这个函数能够⾃动计算范围内所有的训练分数和验证分数:
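Before plotting, it can help to check what `validation_curve` returns: two arrays with one row per value in `param_range` and one column per cross-validation fold, which is why the plotting code below takes the median along axis 1. The sketch below uses a smaller degree range (chosen only to keep it quick) and repeats the helpers from In [10] and In [11] so it runs on its own. It also picks the degree with the best median validation score. Note that `param_name` and `param_range` are passed as keywords, which recent Scikit-Learn versions require.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import validation_curve

def PolynomialRegression(degree=2, **kwargs):
    # Same helper as In [10]: polynomial features feeding a linear model
    return make_pipeline(PolynomialFeatures(degree), LinearRegression(**kwargs))

def make_data(N, err=1.0, rseed=1):
    # Same generator as In [11]
    rng = np.random.RandomState(rseed)
    X = rng.rand(N, 1) ** 2
    y = 10 - 1. / (X.ravel() + 0.1)
    if err > 0:
        y += err * rng.randn(N)
    return X, y

X, y = make_data(40)
degree = np.arange(0, 6)
train_score, val_score = validation_curve(
    PolynomialRegression(), X, y,
    param_name='polynomialfeatures__degree', param_range=degree, cv=7)

# One row per degree, one column per CV fold
print(train_score.shape, val_score.shape)   # (6, 7) (6, 7)

# Median over folds (axis 1), then the degree with the best validation score
best_degree = degree[np.argmax(np.median(val_score, axis=1))]
print(best_degree)
```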
In [13]: from sklearn.model_selection import validation_curve
degree = np.arange(0, 21)
train_score, val_score = validation_curve(PolynomialRegression(), X, y,
                                          param_name='polynomialfeatures__degree',
                                          param_range=degree, cv=7)
plt.plot(degree, np.median(train_score, 1), color='blue', label='training score')
plt.plot(degree, np.median(val_score, 1), color='red', label='validation score')
plt.legend(loc='best')
plt.ylim(0, 1)
plt.xlabel('degree')
plt.ylabel('score');
This shows precisely the qualitative behavior we expect: the training score is everywhere higher than the validation score;
the training score is monotonically improving with increased model complexity; and the validation score reaches a
maximum before dropping off as the model becomes over-fit.
上图精确地展示了我们期望的定性行为:训练分数在任何地方都高于验证分数;训练分数随着模型复杂度增加而单调递增;然而验证分数在达到最大值后会因为过拟合而开始下降。
From the validation curve, we can read off that the optimal trade-off between bias and variance is found for a third-order
polynomial; we can compute and display this fit over the original data as follows:
从验证曲线中,我们可以看到最优的偏差和⽅差平衡出现在三阶的多项式附近;我们可以在原始数据上计算并展⽰这个模型:
In [14]: plt.scatter(X.ravel(), y)
lim = plt.axis()
y_test = PolynomialRegression(3).fit(X, y).predict(X_test)
plt.plot(X_test.ravel(), y_test);
plt.axis(lim);
Notice that finding this optimal model did not actually require us to compute the training score, but examining the
relationship between the training score and validation score can give us useful insight into the performance of the model.
请注意,寻找这个最优模型实际上并不需要计算训练分数,但是检查训练分数和验证分数之间的关系能帮助我们深入了解模型的性能。
Learning Curves
学习曲线
One important aspect of model complexity is that the optimal model will generally depend on the size of your training
data. For example, let's generate a new dataset with a factor of five more points:
模型复杂度的一个重要方面是,最优模型通常取决于训练数据的规模。例如,我们创建一个新的数据集,其样本数量是原来的5倍:
In [15]: X2, y2 = make_data(200)
plt.scatter(X2.ravel(), y2);
We will duplicate the preceding code to plot the validation curve for this larger dataset; for reference let's over-plot the
previous results as well:
我们重复前⾯的代码来绘制这个⼤的数据集的验证曲线;为了对⽐我们将前⾯的结果也⽤虚线画出来:
In [16]: degree = np.arange(21)
train_score2, val_score2 = validation_curve(PolynomialRegression(), X2, y2,
                                            param_name='polynomialfeatures__degree',
                                            param_range=degree, cv=7)
plt.plot(degree, np.median(train_score2, 1), color='blue', label='training score')
plt.plot(degree, np.median(val_score2, 1), color='red', label='validation score')
plt.plot(degree, np.median(train_score, 1), color='blue', alpha=0.3, linestyle='dashed')
plt.plot(degree, np.median(val_score, 1), color='red', alpha=0.3, linestyle='dashed')
plt.legend(loc='lower center')
plt.ylim(0, 1)
plt.xlabel('degree')
plt.ylabel('score');
The solid lines show the new results, while the fainter dashed lines show the results of the previous smaller dataset. It is
clear from the validation curve that the larger dataset can support a much more complicated model: the peak here is
probably around a degree of 6, but even a degree-20 model is not seriously over-fitting the data—the validation and
training scores remain very close.
实线展⽰新的结果,⽽虚线展⽰的是前⾯⼩数据集的结果。从验证曲线很明显看出⼤的数据集能够⽀持更复杂的模型:上图中的峰值⼤约
出现在阶数6的位置上,但是甚⾄到了20阶的多项式模型中,也没有出现严重的过拟合,验证分数和训练分数依然很接近。
Thus we see that the behavior of the validation curve has not one but two important inputs: the model complexity and the
number of training points. It is often useful to explore the behavior of the model as a function of the number of training
points, which we can do by using increasingly larger subsets of the data to fit our model. A plot of the training/validation
score with respect to the size of the training set is known as a learning curve.
因此我们看到了验证曲线不⽌有⼀个⽽是有两个重要的输⼊参数:模型复杂度和数据样本量。研究模型的性能与样本量之间的关系函数经
常也很有帮助,我们可以通过不断增加数据中⽤来训练的⼦数据集规模来进⾏研究。绘制⼀幅训练/验证分数随着训练集规模变化的图像被
称为学习曲线。
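A learning curve can be sketched by hand by fitting the same model to growing subsets of the data and scoring each fit both on the subset it saw and on held-out points. The sketch below reuses the `make_data` generator from In [11] (redefined so it runs on its own). As a simplifying assumption it draws a separate validation set with a different seed instead of cross-validating; `learning_curve`, used later in this section, does the cross-validated version. The subset sizes and the degree-9 model are arbitrary choices.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline

def make_data(N, err=1.0, rseed=1):
    # Same generator as In [11]
    rng = np.random.RandomState(rseed)
    X = rng.rand(N, 1) ** 2
    y = 10 - 1. / (X.ravel() + 0.1)
    if err > 0:
        y += err * rng.randn(N)
    return X, y

X, y = make_data(200)
X_val, y_val = make_data(100, rseed=2)   # separate validation set (seed chosen arbitrarily)

sizes = [10, 25, 50, 100, 200]
train_curve, val_curve = [], []
for n in sizes:
    # Fit a fixed-complexity model on the first n training points only
    model = make_pipeline(PolynomialFeatures(9), LinearRegression())
    model.fit(X[:n], y[:n])
    train_curve.append(model.score(X[:n], y[:n]))   # R^2 on the data it saw
    val_curve.append(model.score(X_val, y_val))      # R^2 on unseen data

for n, tr, va in zip(sizes, train_curve, val_curve):
    print(f"n={n}: train R^2={tr:.3f}, validation R^2={va:.3f}")
```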
The general behavior we would expect from a learning curve is this:
A model of a given complexity will overfit a small dataset: this means the training score will be relatively high, while
the validation score will be relatively low.
A model of a given complexity will underfit a large dataset: this means that the training score will decrease, but the
validation score will increase.
A model will never, except by chance, give a better score to the validation set than the training set: this means the
curves should keep getting closer together but never cross.
从学习曲线中我们⼀般可以观察到下⾯的结论:
在⼩数据集的情况下,⼀个给定复杂度的模型很可能会过拟合:这意味着训练分数相对来说⽐较⾼⽽验证分数⽐较低。
在⼤数据集的情况下,⼀个给定复杂度的模型很可能会⽋拟合:这意味着训练分数会下降⽽验证分数会上升。
一个模型(除非偶然)永远不会在验证集上给出比训练集更高的分数:这意味着两根曲线会越来越接近,但是不会相交。
With these features in mind, we would expect a learning curve to look qualitatively like that shown in the following figure:
有了上述结论,我们预计的学习曲线如下图:
附录中⽣成图像的代码
The notable feature of the learning curve is the convergence to a particular score as the number of training samples
grows. In particular, once you have enough points that a particular model has converged, adding more training data will
not help you! The only way to increase model performance in this case is to use another (often more complex) model.
学习曲线的一个显著特征是,随着训练样本量的增加,分数会收敛到某个特定值。特别地,一旦你的样本量已经足以使某种模型收敛,
增加更多的训练数据也不会提供任何帮助!在这种情况下提升模型性能的唯一方法就是使用另一个(通常更复杂的)模型。
Learning curves in Scikit-Learn
Scikit-Learn中的学习曲线
Scikit-Learn offers a convenient utility for computing such learning curves from your models; here we will compute a
learning curve for our original dataset with a second-order polynomial model and a ninth-order polynomial:
Scikit-Learn提供了一个方便的工具来计算模型的学习曲线;下面我们计算原始数据集在二阶多项式模型和九阶多项式模型上的学习曲线:
In [17]: from sklearn.model_selection import learning_curve
fig, ax = plt.subplots(1, 2, figsize=(16, 6))
fig.subplots_adjust(left=0.0625, right=0.95, wspace=0.1)
for i, degree in enumerate([2, 9]):
    N, train_lc, val_lc = learning_curve(PolynomialRegression(degree),
                                         X, y, cv=7,
                                         train_sizes=np.linspace(0.3, 1, 25))
    ax[i].plot(N, np.mean(train_lc, 1), color='blue', label='training score')
    ax[i].plot(N, np.mean(val_lc, 1), color='red', label='validation score')
    ax[i].hlines(np.mean([train_lc[-1], val_lc[-1]]), N[0], N[-1],
                 color='gray', linestyle='dashed')
    ax[i].set_ylim(0, 1)
    ax[i].set_xlim(N[0], N[-1])
    ax[i].set_xlabel('training size')
    ax[i].set_ylabel('score')
    ax[i].set_title('degree = {0}'.format(degree), size=14)
    ax[i].legend(loc='best')
This is a valuable diagnostic, because it gives us a visual depiction of how our model responds to increasing training
data. In particular, when your learning curve has already converged (i.e., when the training and validation curves are
already close to each other) adding more training data will not significantly improve the fit! This situation is seen in the left
panel, with the learning curve for the degree-2 model.
这是一项非常有价值的分析,因为它为我们提供了模型性能随着训练数据增加而变化的可视化展示。特别是当你的学习曲线已经收敛时
(即训练曲线和验证曲线已经非常接近的情况下),增加更多的训练数据不会显著地提升拟合度。这种情况可以从左图二阶模型的学习曲
线中看到。
The only way to increase the converged score is to use a different (usually more complicated) model. We see this in the
right panel: by moving to a much more complicated model, we increase the score of convergence (indicated by the
dashed line), but at the expense of higher model variance (indicated by the difference between the training and validation
scores). If we were to add even more data points, the learning curve for the more complicated model would eventually
converge.
要提升已经收敛的分数,唯一方法就是使用一个不同的(通常更复杂的)模型。我们可以从右图中看到:当使用了复杂得多的模
型后,我们将收敛的分数值(使用虚线表示)提升了,付出的代价是更高的模型方差(图中训练曲线和验证曲线的间距)。如果我们继续
增加更多的样本,更复杂模型的学习曲线最终也会收敛。
Plotting a learning curve for your particular choice of model and dataset can help you to make this type of decision about
how to move forward in improving your analysis.
绘制模型和数据集的学习曲线能帮助你作出进⼀步改善性能的决定。
Validation in Practice: Grid Search
验证实践:⽹格搜索
The preceding discussion is meant to give you some intuition into the trade-off between bias and variance, and its
dependence on model complexity and training set size. In practice, models generally have more than one knob to turn,
and thus plots of validation and learning curves change from lines to multi-dimensional surfaces. In these cases, such
visualizations are difficult and we would rather simply find the particular model that maximizes the validation score.
前面的讨论意在让你直观地理解偏差和方差的权衡,以及它如何取决于模型复杂度和训练集规模。在实践中,模型通常有多于一个开关进行调
节,因此验证曲线和学习曲线就会从二维线条变成多维曲面。在这些情况下,要将它可视化出来是很困难的,我们更希望简
单地找到能最大化验证分数的特定模型。
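Scikit-Learn provides automated tools to do this in the grid search module. Here is an example of using grid search to
find the optimal polynomial model. We will explore a three-dimensional grid of model features: namely the polynomial
degree, the flag telling us whether to fit the intercept, and the flag telling us whether to normalize the problem. This can
be set up using Scikit-Learn's GridSearchCV meta-estimator. (The normalize parameter of LinearRegression was
deprecated in Scikit-Learn 1.0 and removed in 1.2; on current versions, drop that entry from param_grid, in which case
the best parameters found may differ from the output shown below.)
Scikit-Learn在网格搜索模块中提供了自动化的工具来完成这项任务。下面是一个使用网格搜索找到最优多项式模型的例子。我们会
探索模型特征的一个三维网格:包括多项式阶数、一个是否拟合截距的标志和一个是否对问题进行归一化的标志。这可以通过Scikit-Learn的
GridSearchCV 元评估器来设置。(LinearRegression 的 normalize 参数在Scikit-Learn 1.0中被弃用,并在1.2中被移除;在当前版本中需要从 param_grid 中删去这一项,此时找到的最佳参数可能与下面显示的输出不同。)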
In [18]: from sklearn.model_selection import GridSearchCV
param_grid = {'polynomialfeatures__degree': np.arange(21),
'linearregression__fit_intercept': [True, False],
'linearregression__normalize': [True, False]}
grid = GridSearchCV(PolynomialRegression(), param_grid, cv=7)
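To get a feel for how much work this grid represents, `ParameterGrid` (the helper GridSearchCV uses to enumerate a dictionary grid) can list the combinations directly. It only forms the combinations and does not check the parameter names against any estimator, so the original three-dimensional grid can be counted even on versions where `normalize` no longer exists. With the default `refit=True`, GridSearchCV also refits the best setting once on the full data after the cross-validated fits.

```python
import numpy as np
from sklearn.model_selection import ParameterGrid

# The same dictionary as in In [18]
param_grid = {'polynomialfeatures__degree': np.arange(21),
              'linearregression__fit_intercept': [True, False],
              'linearregression__normalize': [True, False]}

candidates = list(ParameterGrid(param_grid))
n_folds = 7
print(len(candidates))             # 21 * 2 * 2 = 84 candidate settings
print(len(candidates) * n_folds)   # 588 cross-validation fits for cv=7
```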
Notice that like a normal estimator, this has not yet been applied to any data. Calling the fit() method will fit the
model at each grid point, keeping track of the scores along the way:
⽹格搜索模型和普通模型⼀样,实例化后还未应⽤到任何数据上。通过调⽤ fit() ⽅法会将模型的每个⽹格点拟合到数据上,同时过程
中保存了验证的分数:
In [19]: grid.fit(X, y);
Now that this is fit, we can ask for the best parameters as follows:
拟合完后,我们可以使⽤下⾯代码来获得最佳参数:
In [20]: grid.best_params_
Out[20]: {'linearregression__fit_intercept': False,
'linearregression__normalize': True,
'polynomialfeatures__degree': 4}
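On a current Scikit-Learn version, a sketch of the same search using only the two parameters that `LinearRegression` still accepts looks like this. The best parameters it finds are not guaranteed to match Out[20]. It also shows where `best_params_` comes from: `cv_results_` records the mean cross-validated score of every candidate, and `best_score_` is the highest of these.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import GridSearchCV

def PolynomialRegression(degree=2, **kwargs):
    # Same helper as In [10]
    return make_pipeline(PolynomialFeatures(degree), LinearRegression(**kwargs))

def make_data(N, err=1.0, rseed=1):
    # Same generator as In [11]
    rng = np.random.RandomState(rseed)
    X = rng.rand(N, 1) ** 2
    y = 10 - 1. / (X.ravel() + 0.1)
    if err > 0:
        y += err * rng.randn(N)
    return X, y

X, y = make_data(40)

# Two-dimensional grid: 'normalize' is omitted because current
# versions of LinearRegression no longer accept it
param_grid = {'polynomialfeatures__degree': np.arange(21),
              'linearregression__fit_intercept': [True, False]}
grid = GridSearchCV(PolynomialRegression(), param_grid, cv=7)
grid.fit(X, y)

# Mean validation score of every candidate, in the order of cv_results_['params']
mean_scores = grid.cv_results_['mean_test_score']
print(grid.best_params_)
print(grid.best_score_, np.nanmax(mean_scores))
```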
Finally, if we wish, we can use the best model and show the fit to our data using code from before:
最终,需要的话,我们可以使⽤代码将最佳模型、数据及它们的拟合情况绘制出来:
In [21]: model = grid.best_estimator_
plt.scatter(X.ravel(), y)
lim = plt.axis()
y_test = model.fit(X, y).predict(X_test)
plt.plot(X_test.ravel(), y_test);
plt.axis(lim);
The grid search provides many more options, including the ability to specify a custom scoring function, to parallelize the
computations, to do randomized searches, and more. For more information, see the examples in In-Depth: Kernel Density
Estimation and Feature Engineering: Working with Images, or refer to Scikit-Learn's grid search documentation.
网格搜索提供很多其他选项,包括指定自定义的评分函数、并行化计算和执行随机搜索等等。需要更多信息,参见深入:核密度估计和特
征工程:使用图像数据中的例子,或者参考Scikit-Learn的网格搜索在线文档。
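A minimal sketch of some of these options, assuming the same polynomial pipeline and data as above (redefined so the block runs on its own). `RandomizedSearchCV` evaluates a fixed number of sampled settings (`n_iter`) instead of the whole grid, `scoring` swaps in a different built-in scorer, and `n_jobs` controls parallelism. The specific choices here (`n_iter=10`, negated mean squared error, `random_state=0`) are arbitrary.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import RandomizedSearchCV

def PolynomialRegression(degree=2, **kwargs):
    # Same helper as In [10]
    return make_pipeline(PolynomialFeatures(degree), LinearRegression(**kwargs))

def make_data(N, err=1.0, rseed=1):
    # Same generator as In [11]
    rng = np.random.RandomState(rseed)
    X = rng.rand(N, 1) ** 2
    y = 10 - 1. / (X.ravel() + 0.1)
    if err > 0:
        y += err * rng.randn(N)
    return X, y

X, y = make_data(40)
param_distributions = {'polynomialfeatures__degree': list(range(21)),
                       'linearregression__fit_intercept': [True, False]}

search = RandomizedSearchCV(
    PolynomialRegression(), param_distributions,
    n_iter=10,                          # try 10 of the 42 combinations
    scoring='neg_mean_squared_error',   # errors are negated so that higher is better
    cv=7,
    n_jobs=1,                           # n_jobs=-1 would spread the fits over all CPU cores
    random_state=0)                     # makes the sampled settings reproducible
search.fit(X, y)
print(search.best_params_, search.best_score_)
```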
Summary
总结
In this section, we have begun to explore the concept of model validation and hyperparameter optimization, focusing on
intuitive aspects of the bias–variance trade-off and how it comes into play when fitting models to data. In particular, we
found that the use of a validation set or cross-validation approach is vital when tuning parameters in order to avoid overfitting for more complex/flexible models.
在本节中,我们开始探讨模型验证和超参数优化的概念,聚焦在偏差方差权衡的直观概念和它在模型拟合数据时扮演的角色。特别是,我
们发现在调节参数时,使用验证集或交叉验证方法对于避免复杂/灵活模型的过拟合至关重要。
In later sections, we will discuss the details of particularly useful models, and throughout will talk about what tuning is
available for these models and how these free parameters affect model complexity. Keep the lessons of this section in
mind as you read on and learn about these machine learning approaches!
在后续章节中,我们会讨论每种模型的细节,并在过程中介绍这些模型可以调节哪些参数以及这些参数如何影响模型复杂度。请将本节的
内容牢记,当你在后⾯继续学习机器学习⽅法的时候,本节内容会提供重要的帮助。
Feature Engineering
特征⼯程
The previous sections outline the fundamental ideas of machine learning, but all of the examples assume that you have
numerical data in a tidy, [n_samples, n_features] format. In the real world, data rarely comes in such a form. With
this in mind, one of the more important steps in using machine learning in practice is feature engineering: that is, taking
whatever information you have about your problem and turning it into numbers that you can use to build your feature
matrix.
上几节中我们描述了机器学习的基本概念,但前面所有的例子都假定你的数据是数值型的,并且是整洁的 [n_samples, n_features]
格式。在现实世界中,数据很少具有这样的格式。考虑到这一点,在实践中使用机器学习时,其中一个重要的步骤就是特征
工程:也就是将你所掌握的关于问题的各种信息转换为数值,以便用来构建特征矩阵。
In this section, we will cover a few common examples of feature engineering tasks: features for representing categorical
data, features for representing text, and features for representing images. Additionally, we will discuss derived features
for increasing model complexity and imputation of missing data. Often this process is known as vectorization, as it
involves converting arbitrary data into well-behaved vectors.
在本节中我们会介绍⼀些特征⼯程任务的通⽤例⼦:表⽰分类数据的特征,表⽰⽂字的特征和表⽰图像的特征。除此之外我们还会讨论派
⽣特征⽤于增加模型复杂度和对缺失值进⾏插值。通常这个过程被称为向量化,因为它意味着将任意数据转变成格式良好的向量。
Categorical Features
分类特征
One common type of non-numerical data is categorical data. For example, imagine you are exploring some data on
housing prices, and along with numerical features like "price" and "rooms", you also have "neighborhood" information.
For example, your data might look something like this:
非数值数据的一个常见类型是分类数据。例如,假设你在研究房价的数据,数据集中除了数值特征如“价格”和“房间数”之外,还会有
“邻近地区”这样的信息。下面例子展示了这个数据的可能情况:
In [2]: data = [
{'price': 850000, 'rooms': 4, 'neighborhood': 'Queen Anne'},
{'price': 700000, 'rooms': 3, 'neighborhood': 'Fremont'},
{'price': 650000, 'rooms': 3, 'neighborhood': 'Wallingford'},
{'price': 600000, 'rooms': 2, 'neighborhood': 'Fremont'}
]
You might be tempted to encode this data with a straightforward numerical mapping:
你可能想要将这个数据直接进⾏数值类型的编码:
In [3]: {'Queen Anne': 1, 'Fremont': 2, 'Wallingford': 3};
It turns out that this is not generally a useful approach in Scikit-Learn: the package's models make the fundamental
assumption that numerical features reflect algebraic quantities. Thus such a mapping would imply, for example, that
Queen Anne < Fremont < Wallingford, or even that Wallingford - Queen Anne = Fremont, which (niche demographic
jokes aside) does not make much sense.
这在Scikit-Learn中不是⼀个实⽤的⽅法:包中的模型基本上假设数值特征表⽰的都是算术量。因此这样的映射会暗⽰⽐如Queen Anne <
Fremont < Wallingford,甚⾄Wallingford - Queen Anne = Fremont,这种转换没有任何含义。
In this case, one proven technique is to use one-hot encoding, which effectively creates extra columns indicating the
presence or absence of a category with a value of 1 or 0, respectively. When your data comes as a list of dictionaries,
Scikit-Learn's DictVectorizer will do this for you:
在这种情况下,一种经过验证的技巧是使用独热编码(one-hot encoding),它能有效地创建额外的列来代表某个类别的存在或缺失,分别使用
数值1或0表示。如果你的数据是字典列表格式,Scikit-Learn的 DictVectorizer 可以帮你完成这项工作:
In [4]: from sklearn.feature_extraction import DictVectorizer
vec = DictVectorizer(sparse=False, dtype=int)
vec.fit_transform(data)
Out[4]: array([[     0,      1,      0, 850000,      4],
               [     1,      0,      0, 700000,      3],
               [     0,      0,      1, 650000,      3],
               [     1,      0,      0, 600000,      2]])
Notice that the 'neighborhood' column has been expanded into three separate columns, representing the three
neighborhood labels, and that each row has a 1 in the column associated with its neighborhood. With these categorical
features thus encoded, you can proceed as normal with fitting a Scikit-Learn model.
上⾯的变换之后'neighborhood'列已经被扩展成为3个独⽴的列,分别代表三个邻近地区的标签,然后每⾏中1所在的列的位置与邻近地区相
关。经过这样的分类特征编码后,你就可以使⽤Scikit-Learn模型进⾏拟合数据了。
To see the meaning of each column, you can inspect the feature names:
要查看每个列的含义,你可以列出特征名称:
In [5]: list(vec.get_feature_names_out())  # Scikit-Learn < 1.0: vec.get_feature_names()
Out[5]: ['neighborhood=Fremont',
'neighborhood=Queen Anne',
'neighborhood=Wallingford',
'price',
'rooms']
There is one clear disadvantage of this approach: if your category has many possible values, this can greatly increase
the size of your dataset. However, because the encoded data contains mostly zeros, a sparse output can be a very
efficient solution:
这种⽅法有⼀个明显的缺点:如果你的分类特征有很多可能的取值,这会极⼤增加你的数据集的⼤⼩。但是因为编码后的数据⼤部分都是0
值,因此输出结果作为稀疏矩阵是⾮常⾼效的:
In [6]: vec = DictVectorizer(sparse=True, dtype=int)
vec.fit_transform(data)
Out[6]: <4x5 sparse matrix of type '<class 'numpy.int64'>'
with 12 stored elements in Compressed Sparse Row format>
Many (though not yet all) of the Scikit-Learn estimators accept such sparse inputs when fitting and evaluating models.
sklearn.preprocessing.OneHotEncoder and sklearn.feature_extraction.FeatureHasher are two
additional tools that Scikit-Learn includes to support this type of encoding.
许多(虽然不是全部)Scikit-Learn评估器接受这样的稀疏输⼊作为模型拟合及预测的参数。
sklearn.preprocessing.OneHotEncoder 和 sklearn.feature_extraction.FeatureHasher 是Scikit-Learn中另外两个支持这
种编码的工具。
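For comparison, here is a minimal sketch of the `OneHotEncoder` alternative mentioned above, applied only to the neighborhood labels from `data` (rewritten as a one-column 2-D array, the input shape this encoder expects). Its categories are sorted, so the result matches the first three columns produced by `DictVectorizer`.

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder

# The 'neighborhood' values from the data above, one sample per row
neighborhoods = np.array([['Queen Anne'], ['Fremont'], ['Wallingford'], ['Fremont']])

enc = OneHotEncoder()                               # returns a sparse matrix by default
onehot = enc.fit_transform(neighborhoods).toarray()

print(enc.categories_)   # sorted categories: Fremont, Queen Anne, Wallingford
print(onehot)
```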
Text Features
⽂字特征
Another common need in feature engineering is to convert text to a set of representative numerical values. For example,
most automatic mining of social media data relies on some form of encoding the text as numbers. One of the simplest
methods of encoding data is by word counts: you take each snippet of text, count the occurrences of each word within it,
and put the results in a table.
另外⼀种特征⼯程常⻅的需求是将⽂字转换成⼀组代表它们的数字值。例如⼤多数社交媒体数据的⾃动挖掘都依赖于某种形式的⽂字到数
字的编码转换。其中最简单的⽅法是进⾏单词计数:选取每⼀⼩段⽂字,计算⾥⾯每个单词出现的次数,然后将它们放到表中。
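The counting itself is simple enough to sketch by hand with `collections.Counter`. Splitting on whitespace is a simplifying assumption that happens to agree with `CountVectorizer`'s default tokenizer for this sample (the default also lowercases text and ignores single-character tokens). Both produce columns in alphabetical order of the vocabulary:

```python
from collections import Counter
from sklearn.feature_extraction.text import CountVectorizer

sample = ['problem of evil', 'evil queen', 'horizon problem']

# Vocabulary: every distinct word, sorted alphabetically
vocab = sorted({word for text in sample for word in text.split()})

# One row per phrase, one column per vocabulary word
counts = [[Counter(text.split())[word] for word in vocab] for text in sample]

# The same table computed by scikit-learn, for comparison
auto = CountVectorizer().fit_transform(sample).toarray()
print(vocab)
print(counts)
```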
For example, consider the following set of three phrases:
以下⾯的三个短语为例:
In [7]: sample = ['problem of evil',
'evil queen',
'horizon problem']
For a vectorization of this data based on word count, we could construct a column representing the word "problem," the
word "evil," the word "horizon," and so on. While doing this by hand would be possible, the tedium can be avoided by
using Scikit-Learn's CountVectorizer :
想要将上⾯的数据使⽤单词计数进⾏向量化,我们可以构造⼀个列代表单词"problem", ⼀个列代表单词"evil",⼀个列代表单
词"horizon"等等。虽然可以⼿⼯完成这项任务,但是你可以使⽤Scikit-Learn的 CountVectorizer 将⾃⼰从重复劳动中解放出来:
In [8]: from sklearn.feature_extraction.text import CountVectorizer
vec = CountVectorizer()
X = vec.fit_transform(sample)
X
Out[8]: <3x5 sparse matrix of type '<class 'numpy.int64'>'
with 7 stored elements in Compressed Sparse Row format>
The result is a sparse matrix recording the number of times each word appears; it is easier to inspect if we convert this to
a DataFrame with labeled columns:
结果是⼀个稀疏矩阵,它记录了每个单词出现的次数;我们将它转成⼀个 DataFrame 就可以很⽅便的看到数据:
In [9]: import pandas as pd
pd.DataFrame(X.toarray(), columns=vec.get_feature_names_out())
Out[9]:
   evil  horizon  of  problem  queen
0     1        0   1        1      0
1     1        0   0        0      1
2     0        1   0        1      0
There are some issues with this approach, however: the raw word counts lead to features which put too much weight on
words that appear very frequently, and this can be sub-optimal in some classification algorithms. One approach to fix this
is known as term frequency-inverse document frequency (TF–IDF) which weights the word counts by a measure of how
often they appear in the documents. The syntax for computing these features is similar to the previous example:
然而这种处理方法有一些问题:原始的单词计数会导致特征在频繁出现的单词上放置过多的权重,这对于一些分类算法来说并不是最优
的。解决这个问题的其中一个办法是被称为term frequency-inverse document frequency (TF–IDF)的算法,它会根据单词在各文档中出现
的频繁程度对单词计数进行加权。计算这些特征的语法与前面的例子类似:
In [10]: from sklearn.feature_extraction.text import TfidfVectorizer
vec = TfidfVectorizer()
X = vec.fit_transform(sample)
pd.DataFrame(X.toarray(), columns=vec.get_feature_names_out())
Out[10]:
       evil   horizon        of   problem     queen
0  0.517856  0.000000  0.680919  0.517856  0.000000
1  0.605349  0.000000  0.000000  0.000000  0.795961
2  0.000000  0.795961  0.000000  0.605349  0.000000
For an example of using TF-IDF in a classification problem, see In Depth: Naive Bayes Classification.
使⽤TF-IDF在分类问题中的例⼦,可参⻅深⼊:朴素⻉叶斯分类。
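The TF–IDF values in Out[10] can be reproduced by hand. With the default settings of `TfidfVectorizer` (`smooth_idf=True`, `norm='l2'`), each raw count is multiplied by $\mathrm{idf}(t) = \ln\frac{1 + n}{1 + \mathrm{df}(t)} + 1$, where $n$ is the number of documents and $\mathrm{df}(t)$ is the number of documents containing term $t$; each row is then scaled to unit Euclidean length. For "of" in the first phrase the weight before normalization is $\ln 2 + 1 \approx 1.693$. A sketch of the computation:

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

sample = ['problem of evil', 'evil queen', 'horizon problem']

# Raw term counts; columns are the alphabetically sorted vocabulary
tf = CountVectorizer().fit_transform(sample).toarray().astype(float)

n_docs = tf.shape[0]
df = (tf > 0).sum(axis=0)                    # number of documents containing each term

# Smoothed inverse document frequency (TfidfVectorizer's default, smooth_idf=True)
idf = np.log((1 + n_docs) / (1 + df)) + 1

manual = tf * idf                                          # weight each count by its term's idf
manual /= np.linalg.norm(manual, axis=1, keepdims=True)    # default norm='l2': unit-length rows

auto = TfidfVectorizer().fit_transform(sample).toarray()
print(np.round(manual, 6))
```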
Image Features
图像特征
Another common need is to suitably encode images for machine learning analysis. The simplest approach is what we
used for the digits data in Introducing Scikit-Learn: simply using the pixel values themselves. But depending on the
application, such approaches may not be optimal.
还有⼀种常⻅的需求是将图像编码成适合机器学习分析的数据。最简单的⽅法在Scikit-Learn简介中已经看到过:直接使⽤图像的像素数
据。但是根据应⽤场景不同,这种⽅法可能不是最优的。
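The pixel approach amounts to flattening each image into one row of the feature matrix. A minimal sketch with the small digits dataset that ships with Scikit-Learn (so nothing needs to be downloaded):

```python
import numpy as np
from sklearn.datasets import load_digits

digits = load_digits()              # 8x8 grayscale images of handwritten digits
print(digits.images.shape)          # (1797, 8, 8)

# Flatten each 8x8 image into a 64-element row of pixel intensities
X = digits.images.reshape(len(digits.images), -1)
print(X.shape)                      # (1797, 64)
```

The resulting matrix is the same as the `digits.data` array used in Introducing Scikit-Learn.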
A comprehensive summary of feature extraction techniques for images is well beyond the scope of this section, but you
can find excellent implementations of many of the standard approaches in the Scikit-Image project. For one example of
using Scikit-Learn and Scikit-Image together, see Feature Engineering: Working with Images.
图像中特征提取技术的完整介绍远远超出了本节的范围,但是你可以在Scikit-Image项⽬中找到⼀流的标准⽅法实现。参⻅特征⼯程:使⽤
图像数据中可以看到联合使⽤Scikit-Learn和Scikit-Image的例⼦。
Derived Features
派⽣特征
Another useful type of feature is one that is mathematically derived from some input features. We saw an example of this
in Hyperparameters and Model Validation when we constructed polynomial features from our input data. We saw that we
could convert a linear regression into a polynomial regression not by changing the model, but by transforming the input!
This is sometimes known as basis function regression, and is explored further in In Depth: Linear Regression.
另一个有用的特征类型是从其他输入特征中进行数学计算并派生获得的特征。我们已经在超参数与模型验证中看到了一个例子,我们从输
入数据中构造了多项式特征。该例中我们看到能够将一个线性回归转变成一个多项式回归,这不是通过改变模型实现的,而是通过转变输
入数据实现的。这有时被称为基函数回归,深入:线性回归一节中会更加深入讨论这方面内容。
For example, this data clearly cannot be well described by a straight line:
例如这个数据显然⽆法使⽤直线很好的拟合:
In [11]: %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
x = np.array([1, 2, 3, 4, 5])
y = np.array([4, 2, 1, 3, 7])
plt.scatter(x, y);
Still, we can fit a line to the data using LinearRegression and get the optimal result:
我们仍然可以使⽤ LinearRegression 将直线拟合到数据上:
In [12]: from sklearn.linear_model import LinearRegression
X = x[:, np.newaxis]
model = LinearRegression().fit(X, y)
yfit = model.predict(X)
plt.scatter(x, y)
plt.plot(x, yfit);
It's clear that we need a more sophisticated model to describe the relationship between x and y.
很显然我们需要更加复杂的模型来描述x和y之间的关系。
One approach to this is to transform the data, adding extra columns of features to drive more flexibility in the model. For
example, we can add polynomial features to the data this way:
⼀种⽅案是转换数据,通过增加额外的特征列来增加模型的灵活性。例如,我们可以如下⽅式增加多项式特征:
In [13]: from sklearn.preprocessing import PolynomialFeatures
poly = PolynomialFeatures(degree=3, include_bias=False)
X2 = poly.fit_transform(X)
print(X2)
[[  1.   1.   1.]
 [  2.   4.   8.]
 [  3.   9.  27.]
 [  4.  16.  64.]
 [  5.  25. 125.]]
The derived feature matrix has one column representing $x$, a second column representing $x^2$, and a third column
representing $x^3$. Computing a linear regression on this expanded input gives a much closer fit to our data:
派生的矩阵中第一列代表 $x$,第二列代表 $x^2$,第三列代表 $x^3$。在这个扩增输入上计算线性回归,可以得到对数据更好的拟合:
In [14]: model = LinearRegression().fit(X2, y)
yfit = model.predict(X2)
plt.scatter(x, y)
plt.plot(x, yfit);
This idea of improving a model not by changing the model, but by transforming the inputs, is fundamental to many of the
more powerful machine learning methods. We explore this idea further in In Depth: Linear Regression in the context of
basis function regression. More generally, this is one motivational path to the powerful set of techniques known as kernel
methods, which we will explore in In-Depth: Support Vector Machines.
上面这种不通过改变模型本身、而是通过转换输入数据来改进模型的思路,是很多更强大的机器学习方法的基础。我们会在深入:线性回归一节中结合基
函数回归更加详细地讨论它。更一般地说,这也是引出一组被称为核方法的强大技术的途径之一,我们会在深入:支持向
量机中深入讨论它们。
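A small sketch of the idea, using hypothetical noise-free data generated from $y = 1 + 2x + 3x^2$: the same `LinearRegression` estimator cannot capture the curve from $x$ alone, but it recovers the coefficients once an $x^2$ column is derived from the input.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

# Hypothetical noise-free data from y = 1 + 2x + 3x^2
x = np.linspace(-2, 2, 20)
y = 1 + 2 * x + 3 * x ** 2

X = x[:, np.newaxis]
straight = LinearRegression().fit(X, y)                 # model sees only x

# Derived inputs: columns x and x^2 (no bias column; LinearRegression fits the intercept)
X_quad = PolynomialFeatures(degree=2, include_bias=False).fit_transform(X)
curved = LinearRegression().fit(X_quad, y)              # same model, richer inputs

print(straight.score(X, y), curved.score(X_quad, y))
print(curved.intercept_, curved.coef_)                  # approximately 1 and [2, 3]
```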
Imputation of Missing Data
缺失数据插值
Another common need in feature engineering is handling of missing data. We discussed the handling of missing data in
DataFrame s in Handling Missing Data, and saw that often the NaN value is used to mark missing values. For
example, we might have a dataset that looks like this:
特征⼯程中还有⼀个普遍需求是处理缺失数据。我们在 DataFrame 的处理缺失数据⼀节中讨论过它,⽽且我们也知道通常我们使⽤ NaN
来代表缺失值。例如我们有如下的数据集:
In [15]: from numpy import nan
X = np.array([[ nan, 0,   3  ],
              [ 3,   7,   9  ],
              [ 3,   5,   2  ],
              [ 4,   nan, 6  ],
              [ 8,   8,   1  ]])
y = np.array([14, 16, -1, 8, -5])
When applying a typical machine learning model to such data, we will need to first replace such missing data with some
appropriate fill value. This is known as imputation of missing values, and strategies range from simple (e.g., replacing
missing values with the mean of the column) to sophisticated (e.g., using matrix completion or a robust model to handle
such data).
如果我们想要将典型机器学习模型应用到这个数据上,我们需要首先将缺失数据填充上合适的值。这被称为缺失数据的插值,它的策略从简单
(例如使用列均值填充缺失值)到复杂(例如使用矩阵补全或一个健壮的模型来处理这些数据)都有。
The sophisticated approaches tend to be very application-specific, and we won't dive into them here. For a baseline
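imputation approach, using the mean, median, or most frequent value, Scikit-Learn provides the SimpleImputer class in
sklearn.impute (older versions provided an Imputer class instead; it was deprecated in version 0.20 and removed in 0.22):
复杂的方法一般都是应用场景相关的,我们在这里不会深入研究它们。对于插值的基础方法,如使用均值、中位数或最常见值,Scikit-Learn提供了 SimpleImputer 类:
译者注: Imputer 类已经过时并在0.22版本中被移除,下面使用了 sklearn.impute.SimpleImputer 替换了原代码中的 Imputer 。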
In [17]: from sklearn.impute import SimpleImputer
imp = SimpleImputer(strategy='mean')
X2 = imp.fit_transform(X)
X2
Out[17]: array([[4.5, 0. , 3. ],
[3. , 7. , 9. ],
[3. , 5. , 2. ],
[4. , 5. , 6. ],
[8. , 8. , 1. ]])
We see that in the resulting data, the two missing values have been replaced with the mean of the remaining values in
the column. This imputed data can then be fed directly into, for example, a LinearRegression estimator:
我们可以从结果看到,两个缺失值被替换成了该列的平均值。处理完后的数据就能直接被传递给评估器模型处理,例如线性回归
LinearRegression :
In [18]: model = LinearRegression().fit(X2, y)
model.predict(X2)
Out[18]: array([13.14869292, 14.3784627 , -1.15539732, 10.96606197, -5.33782027])
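The mean is only one of the baseline strategies mentioned above. A minimal sketch of two others that `SimpleImputer` supports, `'median'` and `'most_frequent'`, on the same array (redefined so the block runs on its own). When several values are equally frequent, `'most_frequent'` uses the smallest one.

```python
import numpy as np
from sklearn.impute import SimpleImputer

X = np.array([[np.nan, 0,      3],
              [3,      7,      9],
              [3,      5,      2],
              [4,      np.nan, 6],
              [8,      8,      1]])

X_median = SimpleImputer(strategy='median').fit_transform(X)
X_frequent = SimpleImputer(strategy='most_frequent').fit_transform(X)

print(X_median[0, 0], X_median[3, 1])       # 3.5 6.0
print(X_frequent[0, 0], X_frequent[3, 1])   # 3.0 0.0
```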
Feature Pipelines
特征管道操作
With any of the preceding examples, it can quickly become tedious to do the transformations by hand, especially if you
wish to string together multiple steps. For example, we might want a processing pipeline that looks something like this:
1. Impute missing values using the mean
2. Transform features to quadratic
3. Fit a linear regression
To streamline this type of processing pipeline, Scikit-Learn provides a Pipeline object, which can be used as follows:
In [19]: from sklearn.pipeline import make_pipeline
model = make_pipeline(SimpleImputer(strategy='mean'),
PolynomialFeatures(degree=2),
LinearRegression())
This pipeline looks and acts like a standard Scikit-Learn object, and will apply all the specified steps to any input data.
In [20]: model.fit(X, y) # X with missing values, from above
print(y)
print(model.predict(X))
[14 16 -1 8 -5]
[14. 16. -1. 8. -5.]
All the steps of the model are applied automatically. Notice that for the simplicity of this demonstration, we've applied the
model to the data it was trained on; this is why it was able to perfectly predict the result (refer back to Hyperparameters
and Model Validation for further discussion of this).
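To see what the pipeline does internally, the sketch below refits it and inspects its steps. `make_pipeline` names each step after its class in lowercase, and chaining the fitted steps by hand gives the same predictions as calling the pipeline:

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline

X = np.array([[np.nan, 0, 3],
              [3, 7, 9],
              [3, 5, 2],
              [4, np.nan, 6],
              [8, 8, 1]])
y = np.array([14, 16, -1, 8, -5])

model = make_pipeline(SimpleImputer(strategy='mean'),
                      PolynomialFeatures(degree=2),
                      LinearRegression())
model.fit(X, y)

# Step names are generated automatically from the class names
step_names = list(model.named_steps)

# The same computation, done step by step with the fitted components
imputer = model.named_steps['simpleimputer']
poly = model.named_steps['polynomialfeatures']
linreg = model.named_steps['linearregression']
manual = linreg.predict(poly.transform(imputer.transform(X)))
```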
For some examples of Scikit-Learn pipelines in action, see the following section on naive Bayes classification, as well as
In Depth: Linear Regression, and In-Depth: Support Vector Machines.
In Depth: Naive Bayes Classification
The previous four sections have given a general overview of the concepts of machine learning. In this section and the
ones that follow, we will be taking a closer look at several specific algorithms for supervised and unsupervised learning,
starting here with naive Bayes classification.
Naive Bayes models are a group of extremely fast and simple classification algorithms that are often suitable for very high-dimensional datasets. Because they are so fast and have so few tunable parameters, they end up being very useful as a quick-and-dirty baseline for a classification problem. This section will focus on an intuitive explanation of how naive Bayes classifiers work, followed by a couple of examples of them in action on some datasets.
Bayesian Classification
Naive Bayes classifiers are built on Bayesian classification methods. These rely on Bayes's theorem, which is an equation describing the relationship of conditional probabilities of statistical quantities. In Bayesian classification, we're interested in finding the probability of a label given some observed features, which we can write as $P(L~|~{\rm features})$. Bayes's theorem tells us how to express this in terms of quantities we can compute more directly:
$$P(L~|~{\rm features}) = \frac{P({\rm features}~|~L)\,P(L)}{P({\rm features})}$$
If we are trying to decide between two labels (let's call them $L_1$ and $L_2$), then one way to make this decision is to compute the ratio of the posterior probabilities for each label:
$$\frac{P(L_1~|~{\rm features})}{P(L_2~|~{\rm features})} = \frac{P({\rm features}~|~L_1)}{P({\rm features}~|~L_2)}\frac{P(L_1)}{P(L_2)}$$
The evidence term $P({\rm features})$ is the same for both labels, so it cancels out of this ratio.
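A small worked example may help here. The numbers below are hypothetical and only illustrate the arithmetic. The code applies Bayes's theorem to each label, then confirms that the posterior ratio computed without $P({\rm features})$ gives the same answer:

```python
# Hypothetical numbers for one observed feature vector
likelihood = {'L1': 0.30, 'L2': 0.05}   # P(features | L)
prior = {'L1': 0.20, 'L2': 0.80}        # P(L)

# Evidence P(features), by the law of total probability
evidence = sum(likelihood[k] * prior[k] for k in prior)

# Bayes's theorem for each label
posterior = {k: likelihood[k] * prior[k] / evidence for k in prior}

# The posterior ratio, computed without the evidence term
ratio = (likelihood['L1'] / likelihood['L2']) * (prior['L1'] / prior['L2'])
```

Here the evidence is 0.1, the posteriors are 0.6 and 0.4, and the ratio is 1.5 either way. Even though $L_1$ has the smaller prior, it is the more probable label for this point.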
All we need now is some model by which we can compute $P({\rm features}~|~L_i)$ for each label. Such a model is called a generative model because it specifies the hypothetical random process that generates the data. Specifying this generative model for each label is the main piece of the training of such a Bayesian classifier. The general version of such a training step is a very difficult task, but we can make it simpler through the use of some simplifying assumptions about the form of this model.
This is where the "naive" in "naive Bayes" comes in: if we make very naive assumptions about the generative model for
each label, we can find a rough approximation of the generative model for each class, and then proceed with the
Bayesian classification. Different types of naive Bayes classifiers rest on different naive assumptions about the data, and
we will examine a few of these in the following sections.
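The naive assumption in concrete terms: if the features are taken to be independent given the label, then $P({\rm features}~|~L)$ factors into a product of one-dimensional terms, $\prod_j P(f_j~|~L)$. The sketch below uses binary features with hypothetical per-feature probabilities. This variant is usually called Bernoulli naive Bayes, and all the numbers are illustrative only:

```python
import numpy as np

# Hypothetical P(feature_j = 1 | L) for three binary features
# (rows: labels L1, L2; columns: features)
p_feature = np.array([[0.8, 0.1, 0.5],
                      [0.3, 0.6, 0.5]])
prior = np.array([0.5, 0.5])

x = np.array([1, 0, 1])  # one observed feature vector

# Naive assumption: features are independent given the label,
# so the joint likelihood is a product of per-feature terms
per_feature = np.where(x == 1, p_feature, 1 - p_feature)
likelihood = per_feature.prod(axis=1)

# Normalize likelihood * prior to get the posterior
posterior = likelihood * prior
posterior /= posterior.sum()
```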
We begin with the standard imports:
In [1]: %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
Gaussian Naive Bayes
Perhaps the easiest naive Bayes classifier to understand is Gaussian naive Bayes. In this classifier, the assumption is
that data from each label is drawn from a simple Gaussian distribution. Imagine that you have the following data:
In [2]: from sklearn.datasets import make_blobs
X, y = make_blobs(100, 2, centers=2, random_state=2, cluster_std=1.5)
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='RdBu');
One extremely fast way to create a simple model is to assume that the data is described by a Gaussian distribution with
no covariance between dimensions. This model can be fit by simply finding the mean and standard deviation of the points
within each label, which is all you need to define such a distribution. The result of this naive Gaussian assumption is
shown in the following figure:
(run code in Appendix to generate image)
The ellipses here represent the Gaussian generative model for each label, with larger probability toward the center of the ellipses. With this generative model in place for each class, we have a simple recipe to compute the likelihood $P({\rm features}~|~L_i)$ of any data point under each label. From there we can quickly compute the posterior ratio and determine which label is the most probable for a given point.
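The recipe just described is short enough to write out directly. Below is a minimal from-scratch sketch; `fit_gaussian_nb` and `predict_gaussian_nb` are hypothetical helper names, not Scikit-Learn API. The sketch estimates a per-class mean, variance, and prior. It then sums one-dimensional Gaussian log-densities across features and picks the label with the largest log posterior. Scikit-Learn's `GaussianNB` also adds a tiny variance-smoothing term, so predictions are compared by agreement rate rather than exact equality:

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.naive_bayes import GaussianNB

X, y = make_blobs(100, 2, centers=2, random_state=2, cluster_std=1.5)

def fit_gaussian_nb(X, y):
    """Per-class mean, variance, and prior (axis-aligned Gaussians)."""
    classes = np.unique(y)
    means = np.array([X[y == c].mean(axis=0) for c in classes])
    variances = np.array([X[y == c].var(axis=0) for c in classes])
    priors = np.array([np.mean(y == c) for c in classes])
    return classes, means, variances, priors

def predict_gaussian_nb(Xnew, classes, means, variances, priors):
    # Log-likelihood under each class: a sum over features of 1D Gaussian
    # log-densities, which is the "naive" independence assumption
    diff = Xnew[:, None, :] - means[None, :, :]              # (n, k, d)
    log_lik = -0.5 * (np.log(2 * np.pi * variances)[None, :, :]
                      + diff ** 2 / variances[None, :, :]).sum(axis=2)
    log_post = log_lik + np.log(priors)[None, :]              # (n, k)
    return classes[np.argmax(log_post, axis=1)]

params = fit_gaussian_nb(X, y)
ours = predict_gaussian_nb(X, *params)

sk_model = GaussianNB().fit(X, y)
agreement = np.mean(ours == sk_model.predict(X))
```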
This procedure is implemented in Scikit-Learn's sklearn.naive_bayes.GaussianNB estimator:
In [3]: from sklearn.naive_bayes import GaussianNB
model = GaussianNB()
model.fit(X, y);
Now let's generate some new data and predict the label:
In [4]: rng = np.random.RandomState(0)
Xnew = [-6, -14] + [14, 18] * rng.rand(2000, 2)
ynew = model.predict(Xnew)
Now we can plot this new data to get an idea of where the decision boundary is:
In [5]: plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='RdBu')
lim = plt.axis()
plt.scatter(Xnew[:, 0], Xnew[:, 1], c=ynew, s=20, cmap='RdBu', alpha=0.1)
plt.axis(lim);
We see a slightly curved boundary in the classifications—in general, the boundary in Gaussian naive Bayes is quadratic.
A nice piece of this Bayesian formalism is that it naturally allows for probabilistic classification, which we can compute
using the predict_proba method:
In [6]: yprob = model.predict_proba(Xnew)
yprob[-8:].round(2)
Out[6]: array([[0.89, 0.11],
[1. , 0. ],
[1. , 0. ],
[1. , 0. ],
[1. , 0. ],
[1. , 0. ],
[0. , 1. ],
[0.15, 0.85]])
The columns give the posterior probabilities of the first and second label, respectively. If you are looking for estimates of uncertainty in your classification, Bayesian approaches like this can be useful.
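Two properties of these probabilities are worth checking, sketched here on the same data as above. First, each row sums to one. Second, the label returned by `predict` is the column with the highest posterior. The 0.9 threshold used to flag uncertain points is an arbitrary illustrative choice, not a recommended value:

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.naive_bayes import GaussianNB

X, y = make_blobs(100, 2, centers=2, random_state=2, cluster_std=1.5)
model = GaussianNB().fit(X, y)

rng = np.random.RandomState(0)
Xnew = [-6, -14] + [14, 18] * rng.rand(2000, 2)

yprob = model.predict_proba(Xnew)
ynew = model.predict(Xnew)

# Each row is a distribution over the labels in model.classes_
row_sums = yprob.sum(axis=1)

# The predicted label is the column with the highest posterior
from_proba = model.classes_[np.argmax(yprob, axis=1)]

# A simple uncertainty flag: points whose top posterior is below 0.9
n_uncertain = np.sum(yprob.max(axis=1) < 0.9)
```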
Of course, the final classification will only be as good as the model assumptions that lead to it, which is why Gaussian
naive Bayes often does not produce very good results. Still, in many cases—especially as the number of features
becomes large—this assumption is not detrimental enough to prevent Gaussian naive Bayes from being a useful method.
Multinomial Naive Bayes
The Gaussian assumption just described is by no means the only simple assumption that could be used to specify the
generative distribution for each label. Another useful example is multinomial naive Bayes, where the features are
assumed to be generated from a simple multinomial distribution. The multinomial distribution describes the probability of
observing counts among a number of categories, and thus multinomial naive Bayes is most appropriate for features that
represent counts or count rates.
The idea is precisely the same as before, except that instead of modeling the data distribution with the best-fit Gaussian, we model the data distribution with a best-fit multinomial distribution.
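Concretely, the best-fit multinomial for a class is just its smoothed relative word frequencies. The sketch below uses a hypothetical count matrix with 4 documents and 3 words. It reproduces `MultinomialNB`'s fitted `feature_log_prob_` by hand, assuming the default additive smoothing `alpha=1`. A new document is then scored by weighting those log probabilities with its counts and adding the log class prior:

```python
import numpy as np
from sklearn.naive_bayes import MultinomialNB

# Hypothetical word counts: 4 documents x 3 vocabulary words
X = np.array([[3, 0, 1],
              [2, 1, 0],
              [0, 4, 1],
              [1, 3, 2]])
y = np.array([0, 0, 1, 1])

alpha = 1.0  # additive (Laplace) smoothing, MultinomialNB's default
model = MultinomialNB(alpha=alpha).fit(X, y)

# Best-fit multinomial per class: smoothed relative word frequencies
counts = np.array([X[y == c].sum(axis=0) for c in (0, 1)])
theta = (counts + alpha) / (counts + alpha).sum(axis=1, keepdims=True)
log_prior = np.log(np.bincount(y) / len(y))

# Score a new document: count-weighted log probabilities plus log prior
x_new = np.array([[0, 2, 1]])
scores = x_new @ np.log(theta).T + log_prior
```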
Example: Classifying Text
One place where multinomial naive Bayes is often used is in text classification, where the features are related to word
counts or frequencies within the documents to be classified. We discussed the extraction of such features from text in
Feature Engineering; here we will use the sparse word count features from the 20 Newsgroups corpus to show how we
might classify these short documents into categories.
Let's download the data and take a look at the target names:
In [7]: from sklearn.datasets import fetch_20newsgroups
data = fetch_20newsgroups()
data.target_names
Downloading 20news dataset. This may take a few minutes.
Downloading dataset from https://ndownloader.figshare.com/files/5975967 (14 MB)
Out[7]: ['alt.atheism',
'comp.graphics',
'comp.os.ms-windows.misc',
'comp.sys.ibm.pc.hardware',
'comp.sys.mac.hardware',
'comp.windows.x',
'misc.forsale',
'rec.autos',
'rec.motorcycles',
'rec.sport.baseball',
'rec.sport.hockey',
'sci.crypt',
'sci.electronics',
'sci.med',
'sci.space',
'soc.religion.christian',
'talk.politics.guns',
'talk.politics.mideast',
'talk.politics.misc',
'talk.religion.misc']
For simplicity here, we will select just a few of these categories, and download the training and testing set:
In [8]: categories = ['talk.religion.misc', 'soc.religion.christian',
'sci.space', 'comp.graphics']
train = fetch_20newsgroups(subset='train', categories=categories)
test = fetch_20newsgroups(subset='test', categories=categories)
Here is a representative entry from the data:
In [9]: print(train.data[5])
From: dmcgee@uluhe.soest.hawaii.edu (Don McGee)
Subject: Federal Hearing
Originator: dmcgee@uluhe
Organization: School of Ocean and Earth Science and Technology
Distribution: usa
Lines: 10
Fact or rumor....? Madalyn Murray O'Hare an atheist who eliminated the
use of the bible reading and prayer in public schools 15 years ago is now
going to appear before the FCC with a petition to stop the reading of the
Gospel on the airways of America. And she is also campaigning to remove
Christmas programs, songs, etc from the public schools. If it is true
then mail to Federal Communications Commission 1919 H Street Washington DC
20054 expressing your opposition to her request. Reference Petition number
2493.
In order to use this data for machine learning, we need to be able to convert the content of each string into a vector of
numbers. For this we will use the TF-IDF vectorizer (discussed in Feature Engineering), and create a pipeline that
attaches it to a multinomial naive Bayes classifier:
In [10]: from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline
model = make_pipeline(TfidfVectorizer(), MultinomialNB())
With this pipeline, we can apply the model to the training data, and predict labels for the test data:
In [11]: model.fit(train.data, train.target)
labels = model.predict(test.data)
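You don't need the 20 Newsgroups download to see how the pipeline works. As a self-contained sketch, here is the same `TfidfVectorizer` plus `MultinomialNB` pipeline trained on a tiny hypothetical corpus. The documents and label names are made up for illustration:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

# A tiny hand-made corpus (hypothetical, not from 20 Newsgroups)
docs = ['the rocket reached orbit',
        'a shuttle launch into orbit',
        'prayer and faith in the church',
        'the church sermon on faith']
labels = ['space', 'space', 'religion', 'religion']

model = make_pipeline(TfidfVectorizer(), MultinomialNB())
model.fit(docs, labels)

predicted = model.predict(['rocket orbit', 'faith church'])
```

Because the pipeline stores the vectorizer's vocabulary, raw strings can be passed to `predict` directly. Words never seen during training are simply ignored.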
Now that we have predicted the labels for the test data, we can evaluate them to learn about the performance of the
estimator. For example, here is the confusion matrix between the true and predicted labels for the test data:
In [12]: from sklearn.metrics import confusion_matrix
mat = confusion_matrix(test.target, labels)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,
xticklabels=train.target_names, yticklabels=train.target_names)
plt.xlabel('true label')
plt.ylabel('predicted label');
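Because the heatmap above plots `mat.T`, it helps to be explicit about orientation. In the array returned by `confusion_matrix`, rows index the true labels and columns the predicted labels. The toy labels below are hypothetical and chosen so the counts are easy to check by hand:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = [0, 0, 1, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2, 2]

mat = confusion_matrix(y_true, y_pred)
# mat[i, j] counts samples whose true class is i and predicted class is j.
# Transposing (mat.T), as in the heatmap above, puts predictions on the rows.

# Correct predictions lie on the diagonal
accuracy = np.trace(mat) / mat.sum()
```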
Evidently, even this very simple classifier can successfully separate space talk from computer talk, but it gets confused
between talk about religion and talk about Christianity. This is perhaps an expected area of confusion!
The very cool thing here is that we now have the tools to determine the category for any string, using the predict()
method of this pipeline. Here's a quick utility function that will return the prediction for a single string:
In [13]: def predict_category(s, train=train, model=model):
    pred = model.predict([s])
    return train.target_names[pred[0]]
Let's try it out:
In [14]: predict_category('sending a payload to the ISS')
Out[14]: 'sci.space'
In [15]: predict_category('discussing islam vs atheism')
Out[15]: 'soc.religion.christian'
In [16]: predict_category('determining the screen resolution')
Out[16]: 'comp.graphics'
Remember that this is nothing more sophisticated than a simple probability model for the (weighted) frequency of each
word in the string; nevertheless, the result is striking. Even a very naive algorithm, when used carefully and trained on a
large set of high-dimensional data, can be surprisingly effective.
When to Use Naive Bayes
Because naive Bayesian classifiers make such stringent assumptions about data, they will generally not perform as well
as a more complicated model. That said, they have several advantages:
They are extremely fast for both training and prediction
They provide straightforward probabilistic prediction
They are often very easily interpretable
They have very few (if any) tunable parameters
These advantages mean a naive Bayesian classifier is often a good choice as an initial baseline classification. If it
performs suitably, then congratulations: you have a very fast, very interpretable classifier for your problem. If it does not
perform well, then you can begin exploring more sophisticated models, with some baseline knowledge of how well they
should perform.
Naive Bayes classifiers tend to perform especially well in one of the following situations:
When the naive assumptions actually match the data (very rare in practice)
For very well-separated categories, when model complexity is less important
For very high-dimensional data, when model complexity is less important
The last two points seem distinct, but they actually are related. As the dimension of a dataset grows, it is much less likely for any two points to be found close together (after all, they must be close in every single dimension to be close overall). This means that clusters in high dimensions tend to be more separated, on average, than clusters in low dimensions, assuming the new dimensions actually add information. For this reason, simplistic classifiers like naive Bayes tend to work as well as or better than more complicated classifiers as the dimensionality grows: once you have enough data, even a simple model can be very powerful.
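The first half of this argument is easy to check numerically. The sketch below draws uniformly random points, a hypothetical setup rather than real data. It compares each point's distance to its nearest neighbour with the average distance between points. As the dimension grows, this ratio approaches one, so even the nearest neighbour is nearly as far away as a typical point:

```python
import numpy as np

def nearest_to_mean_distance(n_points, dim, rng):
    """Mean nearest-neighbour distance divided by mean pairwise distance."""
    X = rng.rand(n_points, dim)
    diff = X[:, None, :] - X[None, :, :]
    D = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(D, np.inf)          # ignore each point's distance to itself
    nearest = D.min(axis=1).mean()
    mean_pair = D[np.isfinite(D)].mean()
    return nearest / mean_pair

rng = np.random.RandomState(0)
ratios = {d: nearest_to_mean_distance(200, d, rng) for d in (1, 10, 100)}
```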
In Depth: Linear Regression
Just as naive Bayes (discussed earlier in In Depth: Naive Bayes Classification) is a good starting point for classification
tasks, linear regression models are a good starting point for regression tasks. Such models are popular because they can
be fit very quickly, and are very interpretable. You are probably familiar with the simplest form of a linear regression model
(i.e., fitting a straight line to data) but such models can be extended to model more complicated data behavior.
In this section we will start with a quick intuitive walk-through of the mathematics behind this well-known problem, before moving on to see how linear models can be generalized to account for more complicated patterns in data.
We begin with the standard imports:
In [1]: %matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import numpy as np
Simple Linear Regression
We will start with the most familiar linear regression, a straight-line fit to data. A straight-line fit is a model of the form
$$y = ax + b$$
where $a$ is commonly known as the slope, and $b$ is commonly known as the intercept.
Consider the following data, which is scattered about a line with a slope of 2 and an intercept of -5:
In [2]: rng = np.random.RandomState(1)
x = 10 * rng.rand(50)
y = 2 * x - 5 + rng.randn(50)
plt.scatter(x, y);
We can use Scikit-Learn's LinearRegression estimator to fit this data and construct the best-fit line:
In [3]: from sklearn.linear_model import LinearRegression
model = LinearRegression(fit_intercept=True)
model.fit(x[:, np.newaxis], y)
xfit = np.linspace(0, 10, 1000)
yfit = model.predict(xfit[:, np.newaxis])
plt.scatter(x, y)
plt.plot(xfit, yfit);
The slope and intercept of the data are contained in the model's fit parameters, which in Scikit-Learn are always marked
by a trailing underscore. Here the relevant parameters are coef_ and intercept_ :
In [4]: print("Model slope:    ", model.coef_[0])
print("Model intercept:", model.intercept_)
Model slope:     2.0272088103606953
Model intercept: -4.998577085553204
We see that the results are very close to the inputs, as we might hope.
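For a single feature, the least-squares line has a closed form. The slope is the covariance of x and y divided by the variance of x, and the intercept makes the line pass through the point of means. A short sketch, regenerating the same data as above, checks that this matches `LinearRegression`:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.RandomState(1)
x = 10 * rng.rand(50)
y = 2 * x - 5 + rng.randn(50)

# Ordinary least squares for one feature, in closed form:
#   a = cov(x, y) / var(x),   b = mean(y) - a * mean(x)
a = np.sum((x - x.mean()) * (y - y.mean())) / np.sum((x - x.mean()) ** 2)
b = y.mean() - a * x.mean()

model = LinearRegression().fit(x[:, np.newaxis], y)
```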
The LinearRegression estimator is much more capable than this, however. In addition to simple straight-line fits, it can also handle multidimensional linear models of the form
$$y = a_0 + a_1 x_1 + a_2 x_2 + \cdots$$
where there are multiple $x$ values. Geometrically, this is akin to fitting a plane to points in three dimensions, or fitting a hyper-plane to points in higher dimensions.
The multidimensional nature of such regressions makes them more difficult to visualize, but we can see one of these fits
in action by building some example data, using NumPy's matrix multiplication operator:
In [5]: rng = np.random.RandomState(1)
X = 10 * rng.rand(100, 3)
y = 0.5 + np.dot(X, [1.5, -2., 1.])
model.fit(X, y)
print(model.intercept_)
print(model.coef_)
0.500000000000012
[ 1.5 -2.   1. ]
Here the y data is constructed from three random x values, and the linear regression recovers the coefficients used to
construct the data.
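The same fit can be reproduced with plain NumPy least squares, which makes the role of the intercept explicit. In this sketch a column of ones is prepended to `X` so that $a_0$ is estimated alongside the other coefficients. Because the data above contain no noise, the recovered values should match `[0.5, 1.5, -2, 1]` to floating-point precision:

```python
import numpy as np

rng = np.random.RandomState(1)
X = 10 * rng.rand(100, 3)
y = 0.5 + np.dot(X, [1.5, -2., 1.])

# Prepend a column of ones so the intercept a0 is fit like any other coefficient
A = np.column_stack([np.ones(len(X)), X])
coef, residuals, rank, singular_values = np.linalg.lstsq(A, y, rcond=None)
```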
In this way, we can use the single LinearRegression estimator to fit lines, planes, or hyperplanes to our data. It still
appears that this approach would be limited to strictly linear relationships between variables, but it turns out we can relax
this as well.
Basis Function Regression
One trick you can use to adapt linear regression to nonlinear relationships between variables is to transform the data according to basis functions. We have seen one version of this before, in the PolynomialRegression pipeline used in Hyperparameters and Model Validation and Feature Engineering. The idea is to take our multidimensional linear model:
$$y = a_0 + a_1 x_1 + a_2 x_2 + a_3 x_3 + \cdots$$
and build the $x_1, x_2, x_3,$ and so on, from our single-dimensional input $x$. That is, we let $x_n = f_n(x)$, where $f_n()$ is some function that transforms our data.
For example, if $f_n(x) = x^n$, our model becomes a polynomial regression:
$$y = a_0 + a_1 x + a_2 x^2 + a_3 x^3 + \cdots$$
Notice that this is still a linear model. Here, linearity refers to the fact that the coefficients $a_n$ never multiply or divide each other. What we have effectively done is take our one-dimensional $x$ values and project them into a higher dimension, so that a linear fit can capture more complicated relationships between $x$ and $y$.
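The following sketch shows that nothing beyond ordinary linear regression is involved. It builds the columns $x$, $x^2$, $x^3$ by hand and fits `LinearRegression` to a hypothetical noise-free cubic, $y = 1 + 2x - 0.5x^2 + 0.1x^3$. The fitted intercept and coefficients should recover exactly those values:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.RandomState(0)
x = 4 * rng.rand(30) - 2
y = 1 + 2 * x - 0.5 * x ** 2 + 0.1 * x ** 3     # noise-free cubic

# Basis functions f_n(x) = x**n for n = 1, 2, 3, stacked as columns
features = np.column_stack([x ** n for n in range(1, 4)])

# An ordinary linear model on the transformed features
model = LinearRegression().fit(features, y)
```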
Polynomial basis functions
This polynomial projection is useful enough that it is built into Scikit-Learn, using the PolynomialFeatures
transformer:
In [6]: from sklearn.preprocessing import PolynomialFeatures
x = np.array([2, 3, 4])
poly = PolynomialFeatures(3, include_bias=False)
poly.fit_transform(x[:, None])
Out[6]: array([[ 2., 4., 8.],
[ 3., 9., 27.],
[ 4., 16., 64.]])
We see here that the transformer has converted our one-dimensional array into a three-dimensional array by raising each value to the powers 1, 2, and 3. This new, higher-dimensional data representation can then be plugged into a linear regression.
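With more than one input feature, `PolynomialFeatures` also generates cross terms. This brief sketch uses the default `include_bias=True`, which prepends a constant column. The column order noted in the comment is the one Scikit-Learn produces for two features at degree 2:

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

X = np.array([[2, 3]])
poly = PolynomialFeatures(degree=2)   # include_bias=True by default
out = poly.fit_transform(X)
# Columns: 1, x0, x1, x0**2, x0*x1, x1**2
```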
As we saw in Feature Engineering, the cleanest way to accomplish this is to use a pipeline. Let's make a 7th-degree
polynomial model in this way:
In [7]: from sklearn.pipeline import make_pipeline
poly_model = make_pipeline(PolynomialFeatures(7),
LinearRegression())
With this transform in place, we can use the linear model to fit much more complicated relationships between x and y.
For example, here is a sine wave with noise:
In [8]: rng = np.random.RandomState(1)
x = 10 * rng.rand(50)
y = np.sin(x) + 0.1 * rng.randn(50)
poly_model.fit(x[:, np.newaxis], y)
yfit = poly_model.predict(xfit[:, np.newaxis])
plt.scatter(x, y)
plt.plot(xfit, yfit);
Our linear model, through the use of 7th-order polynomial basis functions, can provide an excellent fit to this non-linear
data!
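One way to quantify "an excellent fit" is the $R^2$ score returned by the pipeline's `score` method. The sketch below compares a few polynomial degrees on the same noisy sine data. These are training scores, which can only go up as the degree increases. Choosing the degree properly calls for validation data, as discussed in Hyperparameters and Model Validation:

```python
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

rng = np.random.RandomState(1)
x = 10 * rng.rand(50)
y = np.sin(x) + 0.1 * rng.randn(50)

scores = {}
for degree in (1, 3, 7):
    model = make_pipeline(PolynomialFeatures(degree), LinearRegression())
    model.fit(x[:, np.newaxis], y)
    # R^2 on the training data (not a substitute for validation)
    scores[degree] = model.score(x[:, np.newaxis], y)
```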
Gaussian basis functions
Of course, other basis functions are possible. For example, one useful pattern is to fit a model that is not a sum of
polynomial bases, but a sum of Gaussian bases. The result might look something like the following figure:
(run code in Appendix to generate image)
The shaded regions in the plot are the scaled basis functions, and when added together they reproduce the smooth curve
through the data. These Gaussian basis functions are not built into Scikit-Learn, but we can write a custom transformer
that will create them, as shown here and illustrated in the following figure (Scikit-Learn transformers are implemented as
Python classes; reading Scikit-Learn's source is a good way to see how they can be created):
In [9]: from sklearn.base import BaseEstimator, TransformerMixin
class GaussianFeatures(BaseEstimator, TransformerMixin):
    """Uniformly spaced Gaussian features for one-dimensional input"""
    def __init__(self, N, width_factor=2.0):
        self.N = N
        self.width_factor = width_factor
    @staticmethod
    def _gauss_basis(x, y, width, axis=None):
        arg = (x - y) / width
        return np.exp(-0.5 * np.sum(arg ** 2, axis))
    def fit(self, X, y=None):
        # create N centers spread along the data range
        self.centers_ = np.linspace(X.min(), X.max(), self.N)
        self.width_ = self.width_factor * (self.centers_[1] - self.centers_[0])
        return self
    def transform(self, X):
        return self._gauss_basis(X[:, :, np.newaxis], self.centers_,
                                 self.width_, axis=1)
gauss_model = make_pipeline(GaussianFeatures(20),
                            LinearRegression())
gauss_model.fit(x[:, np.newaxis], y)
yfit = gauss_model.predict(xfit[:, np.newaxis])
plt.scatter(x, y)
plt.plot(xfit, yfit)
plt.xlim(0, 10);
We put this example here just to make clear that there is nothing magic about polynomial basis functions: if you have
some sort of intuition into the generating process of your data that makes you think one basis or another might be
appropriate, you can use them as well.
Regularization
The introduction of basis functions into our linear regression makes the model much more flexible, but it also can very
quickly lead to over-fitting (refer back to Hyperparameters and Model Validation for a discussion of this). For example, if
we choose too many Gaussian basis functions, we end up with results that don't look so good:
In [10]: model = make_pipeline(GaussianFeatures(30),
LinearRegression())
model.fit(x[:, np.newaxis], y)
plt.scatter(x, y)
plt.plot(xfit, model.predict(xfit[:, np.newaxis]))
plt.xlim(0, 10)
plt.ylim(-1.5, 1.5);
With the data projected to the 30-dimensional basis, the model has far too much flexibility and goes to extreme values
between locations where it is constrained by data. We can see the reason for this if we plot the coefficients of the
Gaussian bases with respect to their locations:
In [11]: def basis_plot(model, title=None):
    fig, ax = plt.subplots(2, sharex=True)
    model.fit(x[:, np.newaxis], y)
    ax[0].scatter(x, y)
    ax[0].plot(xfit, model.predict(xfit[:, np.newaxis]))
    ax[0].set(xlabel='x', ylabel='y', ylim=(-1.5, 1.5))
    if title:
        ax[0].set_title(title)
    ax[1].plot(model.steps[0][1].centers_,
               model.steps[1][1].coef_)
    ax[1].set(xlabel='basis location',
              ylabel='coefficient',
              xlim=(0, 10))
model = make_pipeline(GaussianFeatures(30), LinearRegression())
basis_plot(model)
The lower panel of this figure shows the amplitude of the basis function at each location. This is typical over-fitting behavior when basis functions overlap: the coefficients of adjacent basis functions blow up and cancel each other out. We know that such behavior is problematic. It would be nice if we could limit such spikes explicitly in the model by penalizing large values of the model parameters. Such a penalty is known as regularization, and comes in several forms.
Ridge regression (L2 Regularization)
Perhaps the most common form of regularization is known as ridge regression or $L_2$ regularization, sometimes also called Tikhonov regularization. This proceeds by penalizing the sum of squares (2-norms) of the model coefficients; in this case, the penalty on the model fit would be
$$P = \alpha\sum_{n=1}^N \theta_n^2$$
where $\alpha$ is a free parameter that controls the strength of the penalty. This type of penalized model is built into Scikit-Learn with the Ridge estimator:
In [12]: from sklearn.linear_model import Ridge
model = make_pipeline(GaussianFeatures(30), Ridge(alpha=0.1))
basis_plot(model, title='Ridge Regression')
The $\alpha$ parameter is essentially a knob controlling the complexity of the resulting model. In the limit $\alpha \to 0$, we recover the standard linear regression result; in the limit $\alpha \to \infty$, all model responses will be suppressed. One advantage of ridge regression in particular is that it can be computed very efficiently, at hardly more computational cost than the original linear regression model.
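Part of why ridge regression is cheap is that, for a fixed $\alpha$, the penalized least-squares problem still has a closed-form solution, $\theta = (X^TX + \alpha I)^{-1}X^Ty$. The sketch below uses hypothetical random data and leaves out the intercept to keep the algebra minimal. It checks the formula against `Ridge(fit_intercept=False)` and illustrates both limits of $\alpha$ described above:

```python
import numpy as np
from sklearn.linear_model import Ridge, LinearRegression

rng = np.random.RandomState(0)
X = rng.randn(50, 5)
y = X @ np.array([1.0, -2.0, 0.0, 0.5, 3.0]) + 0.1 * rng.randn(50)

def ridge_closed_form(X, y, alpha):
    # theta = (X^T X + alpha I)^{-1} X^T y, solved without an explicit inverse
    n_features = X.shape[1]
    return np.linalg.solve(X.T @ X + alpha * np.eye(n_features), X.T @ y)

theta = ridge_closed_form(X, y, alpha=1.0)
sk = Ridge(alpha=1.0, fit_intercept=False).fit(X, y).coef_

# Limits: alpha -> 0 recovers ordinary least squares; large alpha shrinks to 0
ols = LinearRegression(fit_intercept=False).fit(X, y).coef_
near_ols = ridge_closed_form(X, y, alpha=1e-8)
shrunk = ridge_closed_form(X, y, alpha=1e6)
```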
Lasso regression (L1 regularization)
Another very common type of regularization is known as lasso, and involves penalizing the sum of absolute values (1-norms) of regression coefficients:
$$P = \alpha\sum_{n=1}^N |\theta_n|$$
Though this is conceptually very similar to ridge regression, the results can differ surprisingly. For example, for geometric reasons lasso regression tends to favor sparse models where possible; that is, it preferentially sets model coefficients to exactly zero.
We can see this behavior by duplicating the ridge regression figure, but using an L1 penalty on the coefficients:
Translator's note: the code below adds the tol parameter to the Lasso regression model to avoid a convergence warning.
In [13]: from sklearn.linear_model import Lasso
model = make_pipeline(GaussianFeatures(30), Lasso(alpha=0.001, tol=0.01))
basis_plot(model, title='Lasso Regression')
With the lasso regression penalty, the majority of the coefficients are exactly zero, with the functional behavior being
modeled by a small subset of the available basis functions. As with ridge regularization, the α parameter tunes the
strength of the penalty, and should be determined via, for example, cross-validation (refer back to Hyperparameters and
Model Validation for a discussion of this).
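The sparsity difference can also be counted directly. This sketch uses hypothetical synthetic data in which only the first two of ten features influence the target. Lasso drives most of the irrelevant coefficients to exactly zero, while Ridge merely shrinks them. The value `alpha=0.1` is an illustrative choice, not a tuned one:

```python
import numpy as np
from sklearn.linear_model import Lasso, Ridge

rng = np.random.RandomState(42)
X = rng.randn(100, 10)
# Only the first two features matter in this synthetic target
y = 3 * X[:, 0] - 2 * X[:, 1] + 0.5 * rng.randn(100)

lasso = Lasso(alpha=0.1).fit(X, y)
ridge = Ridge(alpha=0.1).fit(X, y)

# Count coefficients that are exactly zero
n_zero_lasso = np.sum(lasso.coef_ == 0)
n_zero_ridge = np.sum(ridge.coef_ == 0)
```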
Example: Predicting Bicycle Traffic
As an example, let's take a look at whether we can predict the number of bicycle trips across Seattle's Fremont Bridge
based on weather, season, and other factors. We have seen this data already in Working With Time Series.
In this section, we will join the bike data with another dataset, and try to determine the extent to which weather and
seasonal factors—temperature, precipitation, and daylight hours—affect the volume of bicycle traffic through this corridor.
Fortunately, the NOAA makes available their daily weather station data (I used station ID USW00024233) and we can
easily use Pandas to join the two data sources. We will perform a simple linear regression to relate weather and other
information to bicycle counts, in order to estimate how a change in any one of these parameters affects the number of
riders on a given day.
In particular, this is an example of how the tools of Scikit-Learn can be used in a statistical modeling framework, in which
the parameters of the model are assumed to have interpretable meaning. As discussed previously, this is not a standard
approach within machine learning, but such interpretation is possible for some models.
Let's start by loading the two datasets, indexing by date:
In [14]: # !curl -o FremontBridge.csv https://data.seattle.gov/api/views/65db-xm6k/rows.csv?accessType=DOWNLOAD
In [15]: import pandas as pd
counts = pd.read_csv('data/FremontBridge.csv', index_col='Date', parse_dates=True)
weather = pd.read_csv('data/BicycleWeather.csv', index_col='DATE', parse_dates=True)
Next we will compute the total daily bicycle traffic, and put this in its own dataframe:
In [16]: daily = counts.resample('d').sum()
daily['Total'] = daily.sum(axis=1)
daily = daily[['Total']]  # remove the other columns
We saw previously that the patterns of use generally vary from day to day; let's account for this in our data by adding
binary columns that indicate the day of the week:
In [17]: days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
for i in range(7):
    daily[days[i]] = (daily.index.dayofweek == i).astype(float)
Similarly, we might expect riders to behave differently on holidays; let's add an indicator of this as well:
In [18]: from pandas.tseries.holiday import USFederalHolidayCalendar
cal = USFederalHolidayCalendar()
holidays = cal.holidays('2012', '2016')
daily = daily.join(pd.Series(1, index=holidays, name='holiday'))
daily['holiday'] = daily['holiday'].fillna(0)
We also might suspect that the hours of daylight would affect how many people ride; let's use the standard astronomical
calculation to add this information:
Translator's note: the code below uses datetime from the standard library instead of pandas.datetime, to avoid a deprecation warning.
In [19]: from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
def hours_of_daylight(date, axis=23.44, latitude=47.61):
    """
    Compute the hours of daylight for the given date.
    axis: 23.44, the tilt of the Earth's axis in degrees
    latitude: 47.61, the latitude of Seattle in degrees
    """
    # 2000-12-21 was the winter solstice, the day with the least daylight
    days = (date - datetime(2000, 12, 21)).days
    m = (1. - np.tan(np.radians(latitude))
         * np.tan(np.radians(axis) * np.cos(days * 2 * np.pi / 365.25)))
    return 24. * np.degrees(np.arccos(1 - np.clip(m, 0, 2))) / 180.
daily['daylight_hrs'] = list(map(hours_of_daylight, daily.index))
daily[['daylight_hrs']].plot()
plt.ylim(8, 17)
Out[19]: (8.0, 17.0)
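As a sanity check of the formula (an addition, not part of the original analysis), the hypothetical helper below evaluates the same expression in terms of days since the 2000-12-21 solstice. It should give about 12 hours a quarter of a year later, near the equinox, and the winter and summer solstice values should add up to 24 hours:

```python
import numpy as np

def daylight_from_days(days, axis=23.44, latitude=47.61):
    """Same expression as hours_of_daylight, but taking the number of days since 2000-12-21."""
    m = (1. - np.tan(np.radians(latitude))
         * np.tan(np.radians(axis) * np.cos(days * 2 * np.pi / 365.25)))
    return 24. * np.degrees(np.arccos(1 - np.clip(m, 0, 2))) / 180.

winter = daylight_from_days(0)                 # winter solstice
equinox = daylight_from_days(365.25 / 4)       # roughly the March equinox
summer = daylight_from_days(365.25 / 2)        # summer solstice
print(f"winter {winter:.2f} h, equinox {equinox:.2f} h, summer {summer:.2f} h")
```

The winter and summer values bracket the plot's 8 to 17 hour range set by plt.ylim above.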
We can also add the average temperature and total precipitation to the data. In addition to the inches of precipitation, let's
add a flag that indicates whether a day is dry (has zero precipitation):
In [20]: # temperatures are in units of 0.1 °C; convert to °C and take the mean of min and max
weather['TMIN'] /= 10
weather['TMAX'] /= 10
weather['Temp (C)'] = 0.5 * (weather['TMIN'] + weather['TMAX'])
# precipitation is in units of 0.1 mm; convert to inches (254 tenths of a mm per inch)
weather['PRCP'] /= 254
weather['dry day'] = (weather['PRCP'] == 0).astype(int)
daily = daily.join(weather[['PRCP', 'Temp (C)', 'dry day']])
Finally, let's add a counter that increases from day 1, and measures how many years have passed. This will let us
measure any observed annual increase or decrease in daily crossings:
In [21]: daily['annual'] = (daily.index - daily.index[0]).days / 365.
Now our data is in order, and we can take a look at it:
In [22]: daily.head()
Out[22]:
             Total  Mon  Tue  Wed  Thu  Fri  Sat  Sun  holiday  daylight_hrs  PRCP  Temp (C)  dry day    annual
Date
2012-10-03  7042.0  0.0  0.0  1.0  0.0  0.0  0.0  0.0      0.0     11.277359   0.0     13.35        1  0.000000
2012-10-04  6950.0  0.0  0.0  0.0  1.0  0.0  0.0  0.0      0.0     11.219142   0.0     13.60        1  0.002740
2012-10-05  6296.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0      0.0     11.161038   0.0     15.30        1  0.005479
2012-10-06  4012.0  0.0  0.0  0.0  0.0  0.0  1.0  0.0      0.0     11.103056   0.0     15.85        1  0.008219
2012-10-07  4284.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0      0.0     11.045208   0.0     15.85        1  0.010959
With this in place, we can choose the columns to use, and fit a linear regression model to our data. We will set
fit_intercept=False, because the daily flags essentially operate as their own day-specific intercepts:
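In [23]: from sklearn.linear_model import LinearRegression
# drop any rows that contain missing values
daily.dropna(axis=0, how='any', inplace=True)
# columns used to fit the model: day of week, holiday, daylight hours, precipitation,
# dry-day flag, temperature, and the annual counter
column_names = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun', 'holiday',
                'daylight_hrs', 'PRCP', 'dry day', 'Temp (C)', 'annual']
X = daily[column_names]
y = daily['Total']
model = LinearRegression(fit_intercept=False)
model.fit(X, y)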
daily['predicted'] = model.predict(X)
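Why fit_intercept=False? Exactly one of the seven day-of-week columns is 1 on every row, so together they already form a constant column, and an extra intercept would be perfectly collinear with them. The NumPy sketch below uses made-up data, not the bicycle counts, to show that a least-squares fit on the indicator columns alone recovers each day's mean, which is what a day-specific intercept is:

```python
import numpy as np

rng = np.random.RandomState(42)
day = rng.randint(0, 7, size=500)                      # hypothetical day-of-week index, 0 = Monday
true_levels = np.array([5., 6., 7., 6., 4., 2., 2.])   # hypothetical per-day baselines
y = true_levels[day] + rng.randn(500)

X = np.eye(7)[day]                                     # one 0/1 indicator column per day
print(X.sum(axis=1)[:5])                               # each row sums to 1: a constant column in disguise

# Appending an explicit intercept column adds no information: rank stays 7 with 8 columns
with_intercept = np.column_stack([np.ones(len(y)), X])
print(np.linalg.matrix_rank(with_intercept), with_intercept.shape[1])

# Least squares on the indicators alone: each coefficient is that day's mean
coef, *_ = np.linalg.lstsq(X, y, rcond=None)
group_means = np.array([y[day == d].mean() for d in range(7)])
print(np.round(coef, 3))
print(np.round(group_means, 3))
```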
Finally, we can compare the total and predicted bicycle traffic visually:
In [24]: daily[['Total', 'predicted']].plot(alpha=0.5);
It is evident that we have missed some key features, especially during the summer time. Either our features are not
complete (i.e., people decide whether to ride to work based on more than just these) or there are some nonlinear
relationships that we have failed to take into account (e.g., perhaps people ride less at both high and low temperatures).
Nevertheless, our rough approximation is enough to give us some insights, and we can take a look at the coefficients of
the linear model to estimate how much each feature contributes to the daily bicycle count:
In [25]: params = pd.Series(model.coef_, index=X.columns)
params
Out[25]:
Mon              1031.250152
Tue              1138.417420
Wed              1241.952231
Thu              1033.464822
Fri               430.231290
Sat             -1949.104025
Sun             -1925.447365
holiday         -2214.497205
daylight_hrs      240.062223
PRCP            -1389.481290
dry day          1031.715058
Temp (C)          135.081658
annual             37.916145
dtype: float64
These numbers are difficult to interpret without some measure of their uncertainty. We can compute these uncertainties
quickly using bootstrap resamplings of the data:
In [26]: from sklearn.utils import resample
np.random.seed(1)
err = np.std([model.fit(*resample(X, y)).coef_
for i in range(1000)], 0)
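The bootstrap in the cell above can be spelled out by hand. This NumPy-only sketch uses synthetic data with a known slope (the data and the helper name bootstrap_coef_std are assumptions): it refits ordinary least squares on rows drawn with replacement and takes the standard deviation of the refitted coefficients, the same idea as combining resample with np.std above:

```python
import numpy as np

def bootstrap_coef_std(X, y, n_boot=1000, seed=1):
    """Standard deviation of least-squares coefficients over bootstrap resamples of the rows."""
    rng = np.random.RandomState(seed)
    n = len(y)
    coefs = []
    for _ in range(n_boot):
        idx = rng.randint(0, n, size=n)        # draw n row indices with replacement
        c, *_ = np.linalg.lstsq(X[idx], y[idx], rcond=None)
        coefs.append(c)
    return np.std(coefs, axis=0)

rng = np.random.RandomState(0)
x = rng.uniform(0, 10, size=200)
X = np.column_stack([np.ones_like(x), x])      # intercept column plus one feature
y = 3.0 + 2.0 * x + rng.randn(200)             # true intercept 3, true slope 2
coef, *_ = np.linalg.lstsq(X, y, rcond=None)
err = bootstrap_coef_std(X, y)
print(np.round(coef, 2), "+/-", np.round(err, 2))
```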
With these errors estimated, let's again look at the results:
In [27]: print(pd.DataFrame({'effect': params.round(0),
'error': err.round(0)}))
               effect  error
Mon            1031.0  283.0
Tue            1138.0  266.0
Wed            1242.0  269.0
Thu            1033.0  276.0
Fri             430.0  261.0
Sat           -1949.0  260.0
Sun           -1925.0  275.0
holiday       -2214.0  478.0
daylight_hrs    240.0   31.0
PRCP          -1389.0  175.0
dry day        1032.0  103.0
Temp (C)        135.0   10.0
annual           38.0  109.0
We first see that there is a relatively stable trend in the weekly baseline: there are many more riders on weekdays than
on weekends and holidays. We see that for each additional hour of daylight, 240 ± 31 more people choose to ride; a
temperature increase of one degree Celsius encourages 135 ± 10 people to grab their bicycle; a dry day means an average
of 1032 ± 103 more riders; and each inch of precipitation means 1389 ± 175 more people leave their bike at home. Once all
these effects are accounted for, the estimated yearly trend is 38 ± 109 daily riders, an uncertainty larger than the effect
itself, so these data do not establish whether daily ridership is rising or falling.
Translator's note: the figures above have been revised to match the translator's own results (the output shown above).
Our model is almost certainly missing some relevant information. For example, nonlinear effects (such as the combined
effect of precipitation and cold temperature) and nonlinear trends within each variable (such as disinclination to ride at very
cold and very hot temperatures) cannot be accounted for in this model. Additionally, we have thrown away some of the
finer-grained information (such as the difference between a rainy morning and a rainy afternoon), and we have ignored
correlations between days (such as the possible effect of a rainy Tuesday on Wednesday's numbers, or the effect of an
unexpected sunny day after a streak of rainy days). These are all potentially interesting effects, and you now have the
tools to begin exploring them if you wish!
In-Depth: Support Vector Machines
Support vector machines (SVMs) are a particularly powerful and flexible class of supervised algorithms for both
classification and regression. In this section, we will develop the intuition behind support vector machines and their use in
classification problems.
We begin with the standard imports:
In [1]: %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
# use Seaborn styles for the plots
import seaborn as sns; sns.set()
Motivating Support Vector Machines
As part of our discussion of Bayesian classification (see In Depth: Naive Bayes Classification), we learned a simple model
describing the distribution of each underlying class, and used these generative models to probabilistically determine
labels for new points. That was an example of generative classification; here we will consider instead discriminative
classification: rather than modeling each class, we simply find a line or curve (in two dimensions) or manifold (in multiple
dimensions) that divides the classes from each other.
As an example of this, consider the simple case of a classification task, in which the two classes of points are well
separated:
Translator's note: the code below no longer imports from the deprecated samples_generator module, to avoid a warning.
In [2]: from sklearn.datasets import make_blobs
X, y = make_blobs(n_samples=50, centers=2,
random_state=0, cluster_std=0.60)
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='autumn');
A linear discriminative classifier would attempt to draw a straight line separating the two sets of data, and thereby create
a model for classification. For two-dimensional data like that shown here, this is a task we could do by hand. But
immediately we see a problem: there is more than one possible dividing line that can perfectly discriminate between the
two classes!
We can draw them as follows:
In [3]: xfit = np.linspace(-1, 3.5)
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='autumn')
plt.plot([0.6], [2.1], 'x', color='red', markeredgewidth=2, markersize=10)
for m, b in [(1, 0.65), (0.5, 1.6), (-0.2, 2.9)]:
    plt.plot(xfit, m * xfit + b, '-k')
plt.xlim(-1, 3.5);
These are three very different separators which, nevertheless, perfectly discriminate between these samples. Depending
on which you choose, a new data point (e.g., the one marked by the "X" in this plot) will be assigned a different label!
Evidently our simple intuition of "drawing a line between classes" is not enough, and we need to think a bit deeper.
Support Vector Machines: Maximizing the Margin
Support vector machines offer one way to improve on this. The intuition is this: rather than simply drawing a zero-width
line between the classes, we can draw around each line a margin of some width, up to the nearest point. Here is an
example of how this might look:
In [4]: xfit = np.linspace(-1, 3.5)
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='autumn')
for m, b, d in [(1, 0.65, 0.33), (0.5, 1.6, 0.55), (-0.2, 2.9, 0.2)]:
    yfit = m * xfit + b
    plt.plot(xfit, yfit, '-k')
    plt.fill_between(xfit, yfit - d, yfit + d, edgecolor='none',
                     color='#AAAAAA', alpha=0.4)
plt.xlim(-1, 3.5);
In support vector machines, the line that maximizes this margin is the one we will choose as the optimal model. Support
vector machines are an example of such a maximum margin estimator.
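For a separator written as w·x + b = 0, the distance from a point x to the line is |w·x + b| / ‖w‖, so the margin of a candidate line is the smallest such distance over the training points (taken with the sign of the label, so a misclassified point gives a negative value). The sketch compares two hand-picked lines on a tiny made-up dataset; a maximum-margin estimator would prefer the one with the larger value:

```python
import numpy as np

def margin(w, b, X, y):
    """Smallest signed distance y_i * (w . x_i + b) / ||w||; negative if any point is misclassified."""
    w = np.asarray(w, dtype=float)
    return np.min(y * (X @ w + b)) / np.linalg.norm(w)

# two small, well-separated groups of points with labels -1 and +1
X = np.array([[0., 0.], [1., 0.], [0., 1.], [3., 3.], [4., 3.], [3., 4.]])
y = np.array([-1, -1, -1, 1, 1, 1])

# candidate lines: x1 + x2 = 3.5 and x1 = 2
for w, b in [((1., 1.), -3.5), ((1., 0.), -2.0)]:
    print(w, b, round(margin(w, b, X, y), 3))
```

Both lines separate the classes perfectly, but the diagonal one leaves a wider gap to the nearest points.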
Fitting a support vector machine
Let's see the result of an actual fit to this data: we will use Scikit-Learn's support vector classifier to train an SVM model
on this data. For the time being, we will use a linear kernel and set the C parameter to a very large number (we'll
discuss the meaning of these in more depth momentarily).
In [5]: from sklearn.svm import SVC  # "Support vector classifier"
model = SVC(kernel='linear', C=1E10)
model.fit(X, y)
Out[5]: SVC(C=10000000000.0, break_ties=False, cache_size=200, class_weight=None,
coef0=0.0, decision_function_shape='ovr', degree=3, gamma='scale',
kernel='linear', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)
To better visualize what's happening here, let's create a quick convenience function that will plot SVM decision
boundaries for us:
In [6]: def plot_svc_decision_function(model, ax=None, plot_support=True):
    """Plot the decision function for a 2D SVC"""
    if ax is None:
        ax = plt.gca()
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    # create a grid on which to evaluate the model
    x = np.linspace(xlim[0], xlim[1], 30)
    y = np.linspace(ylim[0], ylim[1], 30)
    Y, X = np.meshgrid(y, x)
    xy = np.vstack([X.ravel(), Y.ravel()]).T
    P = model.decision_function(xy).reshape(X.shape)
    # plot the decision boundary and margins
    ax.contour(X, Y, P, colors='k',
               levels=[-1, 0, 1], alpha=0.5,
               linestyles=['--', '-', '--'])
    # plot the support vectors as open black circles
    if plot_support:
        ax.scatter(model.support_vectors_[:, 0],
                   model.support_vectors_[:, 1],
                   s=300, linewidth=1, edgecolors='black', facecolors='none')
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
In [7]: plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='autumn')
plot_svc_decision_function(model);
This is the dividing line that maximizes the margin between the two sets of points. Notice that a few of the training points
just touch the margin: they are indicated by the black circles in this figure. These points are the pivotal elements of this fit,
and are known as the support vectors, and give the algorithm its name. In Scikit-Learn, these points are
stored in the support_vectors_ attribute of the classifier:
In [8]: model.support_vectors_
Out[8]: array([[0.44359863, 3.11530945],
[2.33812285, 3.43116792],
[2.06156753, 1.96918596]])
A key to this classifier's success is that for the fit, only the positions of the support vectors matter; any points further from
the margin which are on the correct side do not modify the fit! Technically, this is because these points do not contribute
to the loss function used to fit the model, so their position and number do not matter so long as they do not cross the
margin.
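One standard way to make this concrete is the hinge loss, max(0, 1 − y·f(x)), for labels y = ±1 and decision function f(x) = w·x + b. It is exactly zero whenever y·f(x) ≥ 1, that is, for any point on the correct side of the margin, so moving such a point, or adding more of them, leaves the loss unchanged. A small NumPy illustration with made-up decision values:

```python
import numpy as np

def hinge_loss(y, f):
    """Per-point hinge loss for labels y in {-1, +1} and decision values f = w . x + b."""
    return np.maximum(0.0, 1.0 - y * f)

y = np.array([1, 1, 1, -1, -1])
f = np.array([3.0, 1.0, 0.4, -5.0, 0.2])   # hypothetical decision-function values
# only the point inside the margin (f=0.4) and the misclassified point (f=0.2) have nonzero loss
print(hinge_loss(y, f))
```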
We can see this, for example, if we plot the model learned from the first 60 points and first 120 points of this dataset:
In [9]: def plot_svm(N=10, ax=None):
    X, y = make_blobs(n_samples=200, centers=2,
                      random_state=0, cluster_std=0.60)
    X = X[:N]
    y = y[:N]
    model = SVC(kernel='linear', C=1E10)
    model.fit(X, y)
    ax = ax or plt.gca()
    ax.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='autumn')
    ax.set_xlim(-1, 4)
    ax.set_ylim(-1, 6)
    plot_svc_decision_function(model, ax)
fig, ax = plt.subplots(1, 2, figsize=(16, 6))
fig.subplots_adjust(left=0.0625, right=0.95, wspace=0.1)
for axi, N in zip(ax, [60, 120]):
    plot_svm(N, axi)
    axi.set_title('N = {0}'.format(N))
In the left panel, we see the model and the support vectors for 60 training points. In the right panel, we have doubled the
number of training points, but the model has not changed: the three support vectors from the left panel are still the
support vectors from the right panel. This insensitivity to the exact behavior of distant points is one of the strengths of the
SVM model.
If you are running this notebook live, you can use IPython's interactive widgets to view this feature of the SVM model
interactively:
In [10]: from ipywidgets import interact, fixed
interact(plot_svm, N=[10, 200], ax=fixed(None));
Beyond linear boundaries: Kernel SVM
Where SVM becomes extremely powerful is when it is combined with kernels. We have seen a version of kernels before,
in the basis function regressions of In Depth: Linear Regression. There we projected our data into higher-dimensional
space defined by polynomials and Gaussian basis functions, and thereby were able to fit for nonlinear relationships with a
linear model.
In SVM models, we can use a version of the same idea. To motivate the need for kernels, let's look at some data that is
not linearly separable:
Translator's note: the code below no longer imports from the deprecated samples_generator module, to avoid a warning.
In [11]: from sklearn.datasets import make_circles
X, y = make_circles(100, factor=.1, noise=.1)
clf = SVC(kernel='linear').fit(X, y)
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='autumn')
plot_svc_decision_function(clf, plot_support=False);
It is clear that no linear discrimination will ever be able to separate this data. But we can draw a lesson from the basis
function regressions in In Depth: Linear Regression, and think about how we might project the data into a higher
dimension such that a linear separator would be sufficient. For example, one simple projection we could use would be to
compute a radial basis function centered on the middle clump:
In [12]: r = np.exp(-(X ** 2).sum(1))
We can visualize this extra data dimension using a three-dimensional plot—if you are running this notebook live, you will
be able to use the sliders to rotate the plot:
In [13]: from mpl_toolkits import mplot3d
def plot_3D(elev=30, azim=30, X=X, y=y):
    ax = plt.subplot(projection='3d')
    ax.scatter3D(X[:, 0], X[:, 1], r, c=y, s=50, cmap='autumn')
    ax.view_init(elev=elev, azim=azim)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('r')
# tuples give sliders for the elevation and azimuth angles
interact(plot_3D, elev=(-90, 90), azim=(-180, 180),
         X=fixed(X), y=fixed(y));
We can see that with this additional dimension, the data becomes trivially linearly separable, by drawing a separating
plane at, say, r=0.7.
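A NumPy-only version of this check: instead of make_circles, it uses noise-free concentric rings of radius 1 and 0.1 (an assumption chosen to mimic factor=.1 without the noise), applies the same radial basis function as In [12], and classifies with a simple threshold at r = 0.7:

```python
import numpy as np

rng = np.random.RandomState(0)
angles = rng.uniform(0, 2 * np.pi, size=100)
radii = np.where(np.arange(100) < 50, 1.0, 0.1)   # outer ring (label 0), inner ring (label 1)
labels = (radii < 0.5).astype(int)
X = np.column_stack([radii * np.cos(angles), radii * np.sin(angles)])

r = np.exp(-(X ** 2).sum(1))      # the same radial basis function as In [12]
pred = (r > 0.7).astype(int)      # a separating "plane" at r = 0.7 in (x, y, r) space
print("inner ring r:", r[labels == 1].min().round(3), "to", r[labels == 1].max().round(3))
print("outer ring r:", r[labels == 0].min().round(3), "to", r[labels == 0].max().round(3))
print("accuracy:", (pred == labels).mean())
```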
Here we had to choose and carefully tune our projection: if we had not centered our radial basis function in the right
location, we would not have seen such clean, linearly separable results. In general, the need to make such a choice is a
problem: we would like to somehow automatically find the best basis functions to use.
One strategy to this end is to compute a basis function centered at every point in the dataset, and let the SVM algorithm
sift through the results. This type of basis function transformation is known as a kernel transformation, as it is based on a
similarity relationship (or kernel) between each pair of points.
A potential problem with this strategy—projecting N points into N dimensions—is that it might become very
computationally intensive as N grows large. However, because of a neat little procedure known as the kernel trick, a fit
on kernel-transformed data can be done implicitly, that is, without ever building the full N-dimensional representation of
the kernel projection! This kernel trick is built into the SVM, and is one of the reasons the method is so powerful.
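The kernel trick is easiest to verify with a polynomial kernel rather than the RBF kernel, whose implicit feature space is infinite-dimensional. For 2D inputs, K(a, b) = (a·b)² equals the ordinary dot product of the explicit features φ(x) = (x₁², √2·x₁x₂, x₂²), so the kernel yields the higher-dimensional inner product without ever constructing φ. The sketch below writes φ out only to check that identity:

```python
import numpy as np

def phi(x):
    """Explicit degree-2 feature map for a 2D point (used here only for checking)."""
    x1, x2 = x
    return np.array([x1 ** 2, np.sqrt(2) * x1 * x2, x2 ** 2])

def poly2_kernel(a, b):
    """Degree-2 homogeneous polynomial kernel, computed in the original 2D space."""
    return np.dot(a, b) ** 2

a = np.array([1.0, 2.0])
b = np.array([3.0, -1.0])
print(poly2_kernel(a, b), np.dot(phi(a), phi(b)))

# Gram matrix for a whole dataset: N x N kernel values, no explicit feature vectors needed
X = np.random.RandomState(0).randn(5, 2)
K = (X @ X.T) ** 2
```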
In Scikit-Learn, we can apply kernelized SVM simply by changing our linear kernel to an RBF (radial basis function)
kernel, using the kernel model hyperparameter:
In [14]: clf = SVC(kernel='rbf', C=1E6, gamma='auto')
clf.fit(X, y)
Out[14]: SVC(C=1000000.0, break_ties=False, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape='ovr', degree=3, gamma='auto', kernel='rbf',
max_iter=-1, probability=False, random_state=None, shrinking=True,
tol=0.001, verbose=False)
In [15]: plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='autumn')
plot_svc_decision_function(clf)
plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1],
s=300, lw=1, facecolors='none');
Using this kernelized support vector machine, we learn a suitable nonlinear decision boundary. This kernel transformation
strategy is used often in machine learning to turn fast linear methods into fast nonlinear methods, especially for models in
which the kernel trick can be used.
Tuning the SVM: Softening Margins
Our discussion thus far has centered around very clean datasets, in which a perfect decision boundary exists. But what if
your data has some amount of overlap? For example, you may have data like this:
In [16]: X, y = make_blobs(n_samples=100, centers=2,
random_state=0, cluster_std=1.2)
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='autumn');
To handle this case, the SVM implementation has a bit of a fudge-factor which "softens" the margin: that is, it allows
some of the points to creep into the margin if that allows a better fit. The hardness of the margin is controlled by a tuning
parameter, most often known as C . For very large C , the margin is hard, and points cannot lie in it. For smaller C , the
margin is softer, and can grow to encompass some points.
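One standard way to write the softened problem (it matches the formulation in the Scikit-Learn SVC documentation) is to minimize ½‖w‖² + C·Σᵢ max(0, 1 − yᵢ(w·xᵢ + b)). The toy computation below uses made-up 1D data with one +1 point sitting close to the boundary, and evaluates that objective for two hand-picked separators: a wide-margin one that lets the awkward point into the margin, and a narrow-margin one that does not. Which one is preferred flips as C grows:

```python
import numpy as np

def soft_margin_objective(w, b, X, y, C):
    """0.5 * ||w||^2 + C * (sum of hinge losses) for a linear separator."""
    w = np.atleast_1d(np.asarray(w, dtype=float))
    hinge = np.maximum(0.0, 1.0 - y * (X @ w + b))
    return 0.5 * w @ w + C * hinge.sum()

X = np.array([[-3.0], [-2.0], [2.0], [3.0], [0.2]])   # one feature per point
y = np.array([-1, -1, 1, 1, 1])                        # the +1 point at 0.2 is the awkward one

wide = (0.5, 0.0)     # margin width 2/|w| = 4; the point at 0.2 falls inside the margin
narrow = (5.0, 0.0)   # margin width 0.4; no point violates it

for C in [0.1, 100.0]:
    obj_wide = soft_margin_objective(*wide, X, y, C)
    obj_narrow = soft_margin_objective(*narrow, X, y, C)
    better = "wide" if obj_wide < obj_narrow else "narrow"
    print(f"C={C}: wide={obj_wide:.3f}, narrow={obj_narrow:.3f} -> prefers {better}")
```

With these numbers the crossover is at C = 13.75; a real SVC solver optimizes over all w and b rather than comparing two candidates.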
The plot shown below gives a visual picture of how a changing C parameter affects the final fit, via the softening of the
margin:
In [17]: X, y = make_blobs(n_samples=100, centers=2,
random_state=0, cluster_std=0.8)
fig, ax = plt.subplots(1, 2, figsize=(16, 6))
fig.subplots_adjust(left=0.0625, right=0.95, wspace=0.1)
for axi, C in zip(ax, [10.0, 0.1]):
    model = SVC(kernel='linear', C=C).fit(X, y)
    axi.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='autumn')
    plot_svc_decision_function(model, axi)
    axi.scatter(model.support_vectors_[:, 0],
                model.support_vectors_[:, 1],
                s=300, lw=1, facecolors='none');
    axi.set_title('C = {0:.1f}'.format(C), size=14)
The optimal value of the C parameter will depend on your dataset, and should be tuned using cross-validation or a
similar procedure (refer back to Hyperparameters and Model Validation).
Example: Face Recognition
As an example of support vector machines in action, let's take a look at the facial recognition problem. We will use the
Labeled Faces in the Wild dataset, which consists of several thousand collated photos of various public figures. A fetcher
for the dataset is built into Scikit-Learn:
In [18]: from sklearn.datasets import fetch_lfw_people
faces = fetch_lfw_people(min_faces_per_person=60)
print(faces.target_names)
print(faces.images.shape)
['Ariel Sharon' 'Colin Powell' 'Donald Rumsfeld' 'George W Bush'
'Gerhard Schroeder' 'Hugo Chavez' 'Junichiro Koizumi' 'Tony Blair']
(1348, 62, 47)
Let's plot a few of these faces to see what we're working with:
In [19]: fig, ax = plt.subplots(3, 5)
for i, axi in enumerate(ax.flat):
    axi.imshow(faces.images[i], cmap='bone')
    axi.set(xticks=[], yticks=[],
            xlabel=faces.target_names[faces.target[i]])
Each image contains [62×47] or nearly 3,000 pixels. We could proceed by simply using each pixel value as a feature, but
often it is more effective to use some sort of preprocessor to extract more meaningful features; here we will use a
principal component analysis (see In Depth: Principal Component Analysis) to extract 150 fundamental components to
feed into our support vector machine classifier. We can do this most straightforwardly by packaging the preprocessor and
the classifier into a single pipeline:
In [20]: from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
# randomized PCA: the PCA estimator with its randomized SVD solver
pca = PCA(n_components=150, whiten=True, svd_solver='randomized', random_state=42)
svc = SVC(kernel='rbf', class_weight='balanced')
model = make_pipeline(pca, svc)
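Conceptually, the PCA step projects each centered image onto its leading principal directions and, with whiten=True, rescales each component to unit variance. A NumPy-only sketch of that idea via the SVD, on random stand-in data (the shapes and the helper name pca_whiten are assumptions, and this is not Scikit-Learn's randomized solver):

```python
import numpy as np

def pca_whiten(X, n_components):
    """Project centered data onto the top principal axes and scale each axis to unit variance."""
    Xc = X - X.mean(axis=0)
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    components = Vt[:n_components]                  # principal axes, one per row
    Z = Xc @ components.T                           # projections onto those axes
    Z /= S[:n_components] / np.sqrt(len(X) - 1)     # divide by each component's standard deviation
    return Z

rng = np.random.RandomState(0)
X = rng.randn(300, 50) @ rng.randn(50, 50)          # correlated stand-in for pixel data
Z = pca_whiten(X, n_components=10)
print(Z.shape)                                      # (300, 10)
print(np.round(Z.std(axis=0, ddof=1), 3))           # each component now has unit variance
```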
For the sake of testing our classifier output, we will split the data into a training and testing set:
In [21]: from sklearn.model_selection import train_test_split
Xtrain, Xtest, ytrain, ytest = train_test_split(faces.data, faces.target,
random_state=42)
Finally, we can use a grid search cross-validation to explore combinations of parameters. Here we will adjust C (which
controls the margin hardness) and gamma (which controls the size of the radial basis function kernel), and determine the
best model:
In [22]: from sklearn.model_selection import GridSearchCV
param_grid = {'svc__C': [1, 5, 10, 50],
'svc__gamma': [0.0001, 0.0005, 0.001, 0.005]}
grid = GridSearchCV(model, param_grid, cv=3)
%time grid.fit(Xtrain, ytrain)
print(grid.best_params_)
CPU times: user 1min 18s, sys: 26.7 s, total: 1min 44s
Wall time: 20 s
{'svc__C': 5, 'svc__gamma': 0.005}
Here the best C (5) lies inside our grid, but the best gamma (0.005) is the largest value we searched. When the optimal
values fall at the edges of the grid like this, we would want to expand the grid to make sure we have found the true optimum.
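That check is easy to automate. The helper below is hypothetical (it is not part of Scikit-Learn) and reports which tuned parameters landed on the boundary of the searched grid, using the param_grid and best_params_ values from the cells above:

```python
def params_on_grid_edge(param_grid, best_params):
    """Return the names of parameters whose best value is the smallest or largest value searched."""
    on_edge = []
    for name, values in param_grid.items():
        if best_params[name] in (min(values), max(values)):
            on_edge.append(name)
    return on_edge

param_grid = {'svc__C': [1, 5, 10, 50],
              'svc__gamma': [0.0001, 0.0005, 0.001, 0.005]}
best_params = {'svc__C': 5, 'svc__gamma': 0.005}      # the result printed above
print(params_on_grid_edge(param_grid, best_params))  # ['svc__gamma']
```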
Now with this cross-validated model, we can predict the labels for the test data, which the model has not yet seen:
In [23]: model = grid.best_estimator_
yfit = model.predict(Xtest)
Let's take a look at a few of the test images along with their predicted values:
In [24]: fig, ax = plt.subplots(4, 6)
for i, axi in enumerate(ax.flat):
    axi.imshow(Xtest[i].reshape(62, 47), cmap='bone')
    axi.set(xticks=[], yticks=[])
    axi.set_ylabel(faces.target_names[yfit[i]].split()[-1],
                   color='black' if yfit[i] == ytest[i] else 'red')
fig.suptitle('Predicted Names; Incorrect Labels in Red', size=14);
Out of this small sample, our optimal estimator mislabeled only a single face (Bush’s face in the bottom row was
mislabeled as Blair). We can get a better sense of our estimator's performance using the classification report, which lists
recovery statistics label by label:
Translator's note: the model's predictions differ somewhat from the original author's results; the text above follows the original, which should not affect the reading.
In [25]: from sklearn.metrics import classification_report
print(classification_report(ytest, yfit,
target_names=faces.target_names))
                   precision    recall  f1-score   support
     Ariel Sharon       0.91      0.67      0.77        15
     Colin Powell       0.85      0.90      0.87        68
  Donald Rumsfeld       0.88      0.74      0.81        31
    George W Bush       0.83      0.94      0.88       126
Gerhard Schroeder       1.00      0.78      0.88        23
      Hugo Chavez       1.00      0.70      0.82        20
Junichiro Koizumi       1.00      0.92      0.96        12
       Tony Blair       0.90      0.90      0.90        42
         accuracy                           0.87       337
        macro avg       0.92      0.82      0.86       337
     weighted avg       0.88      0.87      0.87       337
We might also display the confusion matrix between these classes:
In [26]: from sklearn.metrics import confusion_matrix
mat = confusion_matrix(ytest, yfit)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,
xticklabels=faces.target_names,
yticklabels=faces.target_names)
plt.xlabel('true label')
plt.ylabel('predicted label');
This helps us get a sense of which labels are likely to be confused by the estimator.
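Each row of the classification report can be recomputed from a confusion matrix: precision is the fraction of predictions for a class that were correct, recall is the fraction of the class's true members that were found, and the F1 score is their harmonic mean. A NumPy sketch on a tiny made-up 3-class matrix, laid out as confusion_matrix returns it (rows are true labels, columns are predicted labels, before the transpose used for the heatmap above):

```python
import numpy as np

def per_class_scores(cm):
    """Precision, recall, F1, and support per class from a confusion matrix (rows = true, cols = predicted)."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    precision = tp / cm.sum(axis=0)      # column sums: how often each class was predicted
    recall = tp / cm.sum(axis=1)         # row sums: how many true members each class has
    f1 = 2 * precision * recall / (precision + recall)
    support = cm.sum(axis=1).astype(int)
    return precision, recall, f1, support

cm = np.array([[8, 2, 0],
               [1, 9, 0],
               [0, 3, 7]])
precision, recall, f1, support = per_class_scores(cm)
accuracy = np.trace(cm) / cm.sum()
print(np.round(precision, 2), np.round(recall, 2), np.round(f1, 2), support, round(accuracy, 2))
```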
For a real-world facial recognition task, in which the photos do not come pre-cropped into nice grids, the only difference in
the facial classification scheme is the feature selection: you would need to use a more sophisticated algorithm to find the
faces, and extract features that are independent of the pixelation. For this kind of application, one good option is to make
use of OpenCV, which, among other things, includes pre-trained implementations of state-of-the-art feature extraction
tools for images in general and faces in particular.
Support Vector Machine Summary
We have seen here a brief intuitive introduction to the principles behind support vector machines. These are
powerful classification methods for a number of reasons:
Their dependence on relatively few support vectors means that they are very compact models, and take up very little
memory.
Once the model is trained, the prediction phase is very fast.
Because they are affected only by points near the margin, they work well with high-dimensional data—even data with
more dimensions than samples, which is a challenging regime for other algorithms.
Their integration with kernel methods makes them very versatile, able to adapt to many types of data.
However, SVMs have several disadvantages as well:
The scaling with the number of samples N is $O[N^3]$ at worst, or $O[N^2]$ for efficient implementations. For large
numbers of training samples, this computational cost can be prohibitive.
The results are strongly dependent on a suitable choice for the softening parameter C . This must be carefully
chosen via cross-validation, which can be expensive as datasets grow in size.
The results do not have a direct probabilistic interpretation. This can be estimated via an internal cross-validation
(see the probability parameter of SVC ), but this extra estimation is costly.
With those traits in mind, I generally only turn to SVMs once other simpler, faster, and less tuning-intensive methods have
been shown to be insufficient for my needs. Nevertheless, if you have the CPU cycles to commit to training and cross-validating an SVM on your data, the method can lead to excellent results.
有了上⾯的特性,作者通过仅会在其他简单快速和更少超参调节的⽅法⽆法满⾜的情况下采⽤SVM。然⽽,如果你有很好的计算资源来完
成SVM的训练和交叉验证的话,这个⽅法能提供优异的结果。
In-Depth: Decision Trees and Random Forests
深⼊:决策树和随机森林
Previously we have looked in depth at a simple generative classifier (naive Bayes; see In Depth: Naive Bayes
Classification) and a powerful discriminative classifier (support vector machines; see In-Depth: Support Vector Machines).
Here we'll look at the motivation behind another powerful algorithm: a non-parametric algorithm called random forests.
Random forests are an example of an ensemble method, meaning that they rely on aggregating the results of an ensemble
of simpler estimators. The somewhat surprising result with such ensemble methods is that the sum can be greater than
the parts: that is, a majority vote among a number of estimators can end up being better than any of the individual
estimators doing the voting! We will see examples of this in the following sections. We begin with the standard imports:
前面我们深入地介绍了简单的生成分类器(朴素贝叶斯,参见深入:朴素贝叶斯分类)和强大的判别分类器(支持向量机,参见深入:支持向量机)。下面我们来看另外一种强大的算法:一种被称为随机森林的非参数算法。随机森林是一种集成方法,这意味着它是在汇总一系列较简单评估器结果的基础上建立的。令人惊奇的是,这种集成方法的整体可能强于各个部分之和:即多个评估器的多数投票结果,可能优于参与投票的任何一个独立评估器的结果。我们会在本节后面看到一些例子。首先还是导入需要的包:
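Why a vote can beat every voter can be seen with a small calculation. This sketch assumes an idealized binary task in which each of $n$ voters is independently correct with probability $p$; the function name is hypothetical. Real decision trees trained on the same data are correlated, so actual gains are smaller than this idealized curve.

```python
from math import comb

def majority_vote_accuracy(n_voters, p):
    """Probability that a strict majority of n_voters independent binary
    classifiers, each correct with probability p, is correct (n_voters odd)."""
    k_min = n_voters // 2 + 1
    return sum(comb(n_voters, k) * p**k * (1 - p)**(n_voters - k)
               for k in range(k_min, n_voters + 1))

# Each voter alone is only 60% accurate
for n in (1, 5, 25, 101):
    print(n, round(majority_vote_accuracy(n, 0.6), 3))
```

The accuracy of the vote grows with the number of voters as long as each voter is better than chance and their errors are not perfectly correlated.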
In [1]: %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
Motivating Random Forests: Decision Trees
开始学习随机森林:决策树
Random forests are an example of an ensemble learner built on decision trees. For this reason we'll start by discussing
decision trees themselves.
随机森林是构建在决策树基础上进⾏组合学习的⼀种⽅法。因此我们先讨论⼀下决策树本⾝。
Decision trees are extremely intuitive ways to classify or label objects: you simply ask a series of questions designed to
zero-in on the classification. For example, if you wanted to build a decision tree to classify an animal you come across
while on a hike, you might construct the one shown here:
决策树是⽤来分类或者标记对象的⾮常直观的⽅法:你只需要简单的提出⼀系列设计好的问题,最终达到分类标签即可。例如,如果希望
构建⼀个⽤来分类动物的决策树,你可以构建下⾯的这棵树:
附录中⽣成图像的代码
The binary splitting makes this extremely efficient: in a well-constructed tree, each question will cut the number of options
by approximately half, very quickly narrowing the options even among a large number of classes. The trick, of course,
comes in deciding which questions to ask at each step. In machine learning implementations of decision trees, the
questions generally take the form of axis-aligned splits in the data: that is, each node in the tree splits the data into two
groups using a cutoff value within one of the features. Let's now look at an example of this.
这种⼆元的区分⽅式使得算法⾮常⾼效:在⼀个构造良好的树中,每个问题都会使得剩下的可⽤选项减半,这甚⾄在分类数量很多情况下
也能迅速的得到结果。当然这个效率取决于每⼀步设计问题的技巧。在决策树的机器学习实现中,树中的问题通常都采⽤沿着轴来分割数
据:也就是说,树中的每个节点会在数据的⼀个特征上,根据⼀个阈值⼀分为⼆。下⾯我们看⼀个例⼦。
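The axis-aligned split described above can be sketched by brute force: try every feature and every midpoint between consecutive distinct values, and keep the cutoff with the lowest weighted Gini impurity. Gini impurity is scikit-learn's default criterion for DecisionTreeClassifier, but the helper names below are hypothetical and the library's own split search is far more optimized.

```python
import numpy as np

def gini(labels):
    """Gini impurity of an array of class labels."""
    if len(labels) == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_axis_aligned_split(X, y):
    """Return (feature, threshold, impurity) for the single split
    'X[:, feature] <= threshold' with the lowest weighted Gini impurity."""
    n_samples, n_features = X.shape
    best = (None, None, np.inf)
    for feature in range(n_features):
        values = np.unique(X[:, feature])
        # candidate cutoffs: midpoints between consecutive distinct values
        for threshold in (values[:-1] + values[1:]) / 2:
            left = X[:, feature] <= threshold
            impurity = (left.sum() * gini(y[left]) +
                        (~left).sum() * gini(y[~left])) / n_samples
            if impurity < best[2]:
                best = (feature, threshold, impurity)
    return best

# Toy data: the classes separate cleanly along feature 0
X_demo = np.array([[1.0, 5.0], [2.0, 6.0], [8.0, 5.5], [9.0, 6.5]])
y_demo = np.array([0, 0, 1, 1])
print(best_axis_aligned_split(X_demo, y_demo))
```

A full tree simply applies this search recursively to each of the two resulting groups until the leaves are pure or a stopping rule is hit.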
Creating a decision tree
创建决策树
Consider the following two-dimensional data, which has one of four class labels:
考虑下⾯的⼆维数据,具有四个分类标签:
In [2]: from sklearn.datasets import make_blobs
X, y = make_blobs(n_samples=300, centers=4,
random_state=0, cluster_std=1.0)
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='rainbow');
A simple decision tree built on this data will iteratively split the data along one or the other axis according to some
quantitative criterion, and at each level assign the label of the new region according to a majority vote of points within it.
This figure presents a visualization of the first four levels of a decision tree classifier for this data:
在这个数据上建⽴的简单决策树会沿着两个轴来分类数据,每⼀层的划分都会按照区域中⼤多数数据点的分类标签(多数票)来确定区域
的标签值。下⾯的图像展⽰了头四层的决策树进⾏分类的可视化过程:
附录中⽣成图像的代码
Notice that after the first split, every point in the upper branch remains unchanged, so there is no need to further
subdivide this branch. Except for nodes that contain all of one color, at each level every region is again split along one of
the two features.
上图看到第⼀层分类后,图中上部的分⽀⼀直保持不变,因此没有必要再对这个分⽀进⾏细分了。除⾮某个节点已经达到包含同⼀颜⾊的
⽬的,否则每⼀层的不同区域都是再次沿着两个特征其中之⼀对数据进⾏再次细分。
This process of fitting a decision tree to our data can be done in Scikit-Learn with the DecisionTreeClassifier
estimator:
这个决策树的拟合过程可以通过Scikit-Learn中的 DecisionTreeClassifier 评估器来实现:
In [3]: from sklearn.tree import DecisionTreeClassifier
tree = DecisionTreeClassifier().fit(X, y)
Let's write a quick utility function to help us visualize the output of the classifier:
然后我们写⼀个⼯具函数帮助我们展⽰分类器的数据可视化:
In [4]: def visualize_classifier(model, X, y, ax=None, cmap='rainbow'):
    ax = ax or plt.gca()

    # Plot the training points
    ax.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=cmap,
               clim=(y.min(), y.max()), zorder=3)
    ax.axis('tight')
    ax.axis('off')
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    # Fit the estimator and evaluate it on a 200x200 grid
    model.fit(X, y)
    xx, yy = np.meshgrid(np.linspace(*xlim, num=200),
                         np.linspace(*ylim, num=200))
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    # Fill the predicted class regions
    n_classes = len(np.unique(y))
    contours = ax.contourf(xx, yy, Z, alpha=0.3,
                           levels=np.arange(n_classes + 1) - 0.5,
                           cmap=cmap, zorder=1)
    ax.set(xlim=xlim, ylim=ylim)
Now we can examine what the decision tree classification looks like:
下⾯我们可以看⼀下决策树分类器的分类效果了:
In [5]: visualize_classifier(DecisionTreeClassifier(), X, y)
If you're running this notebook live, you can use the helpers script included in The Online Appendix to bring up an
interactive visualization of the decision tree building process:
如果你在使用交互式的notebook的话,可以使用附录中的工具脚本来动态展示决策树的构建过程:
译者注:helpers_05_08.py文件第31行中的ax.contourf方法会产生一个Warning,预计新版Matplotlib会修复这个问题,该警告是无害的,因此保留下来了。
In [6]: # helpers_05_08 can be found in the online appendix
import helpers_05_08
helpers_05_08.plot_tree_interactive(X, y);
Notice that as the depth increases, we tend to get very strangely shaped classification regions; for example, at a depth of
five, there is a tall and skinny purple region between the yellow and blue regions. It's clear that this is less a result of the
true, intrinsic data distribution, and more a result of the particular sampling or noise properties of the data. That is, this
decision tree, even at only five levels deep, is clearly over-fitting our data.
随着深度(树节点层次)增加,我们会得到⼀个⾮常奇怪的分类区域形状;如上⾯深度为5时,图像下部会出现⼀条很⾼的狭⻓紫⾊区域,
处于绿⾊和蓝⾊区域之间。从直觉上我们就可以知道这是错误的,这个结果不是来源⾃数据的内在分布特性,⽽更像是通过数据中个别的
样本或噪⾳获得的。也就是说决策树即使只有5层深度也发⽣了数据的过拟合。
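One way to see this over-fitting numerically is to hold out part of the data and compare training and test accuracy as the tree is allowed to grow deeper. This is a minimal sketch that reuses the same make_blobs settings as above; the particular train/test split and list of depths are assumptions made for illustration.

```python
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_blobs(n_samples=300, centers=4, random_state=0, cluster_std=1.0)
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, random_state=0)

results = {}
for depth in (1, 2, 3, 5, 10, None):  # None: grow until every leaf is pure
    tree = DecisionTreeClassifier(max_depth=depth, random_state=0).fit(Xtrain, ytrain)
    results[depth] = (tree.score(Xtrain, ytrain), tree.score(Xtest, ytest))
    print(f"max_depth={depth}: train={results[depth][0]:.2f}, "
          f"test={results[depth][1]:.2f}")
```

The fully grown tree fits the training points perfectly, while its test accuracy typically stops improving (or drops) well before that depth: the gap between the two columns is the over-fitting.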
Decision trees and over-fitting
决策树和过拟合
Such over-fitting turns out to be a general property of decision trees: it is very easy to go too deep in the tree, and thus to
fit details of the particular data rather than the overall properties of the distributions they are drawn from. Another way to
see this over-fitting is to look at models trained on different subsets of the data—for example, in this figure we train two
different trees, each on half of the original data:
这种过拟合是决策树经常出现的问题:很容易就会构建⼀个深度太⼤的决策树,这样的树模型会聚焦在数据的特定数据点或噪⾳之上,⽽
不是数据本⾝的分布特性之上。另外⼀种判断过拟合的⽅法是在数据不同⼦集上的训练结果,例如,下⾯两张图表⽰的是在数据集各⼀半
的数据点上训练得到的两个不同的模型:
附录中⽣成图像的代码
It is clear that in some places, the two trees produce consistent results (e.g., in the four corners), while in other places,
the two trees give very different classifications (e.g., in the regions between any two clusters). The key observation is that
the inconsistencies tend to happen where the classification is less certain, and thus by using information from both of
these trees, we might come up with a better result!
很明显地看到,在一些位置上,两棵树产生了一致的结果(例如四个角附近的位置),但是在其他位置上,两个模型给出了差异很大的分类结果(例如在任意两个类别的交界处)。关键在于,这些不一致一般出现在分类确定性较低的位置,因此如果我们同时使用这两棵树的信息的话,可以预计得到更好的结果。
If you are running this notebook live, the following function will allow you to interactively display the fits of trees trained on
a random subset of the data:
如果使用notebook交互模式,下面的函数能动态展示使用数据的随机子集训练得到的模型:
In [7]: # helpers_05_08 can be found in the online appendix
import helpers_05_08
helpers_05_08.randomized_tree_interactive(X, y)
Just as using information from two trees improves our results, we might expect that using information from many trees
would improve our results even further.
上⾯看到使⽤两棵树的信息能改善结果,我们可以预计组合使⽤更多的树的信息能够得到更好的改善结果。
Ensembles of Estimators: Random Forests
评估器合成:随机森林
This notion—that multiple overfitting estimators can be combined to reduce the effect of this overfitting—is what underlies
an ensemble method called bagging. Bagging makes use of an ensemble (a grab bag, perhaps) of parallel estimators,
each of which over-fits the data, and averages the results to find a better classification. An ensemble of randomized
decision trees is known as a random forest.
这种思想,即多个过拟合的评估器可以被组合起来以减少过拟合的影响,正是一种被称为装袋(bagging)的集成方法的基础。装袋将一些并行的评估器组装(类似塞到袋子里)起来,其中的每个评估器都会对数据产生过拟合,然后对结果求平均来得到一个更好的分类。由随机化决策树组成的集成被称为随机森林。
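Before turning to the library's meta-estimator, the bagging idea itself fits in a few lines: draw bootstrap resamples, fit one over-fitting tree per resample, and take a majority vote. The helper names are hypothetical and the hard vote is a simplification; scikit-learn's BaggingClassifier, for instance, averages predicted probabilities when the base estimator provides them rather than counting hard votes.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.tree import DecisionTreeClassifier

def fit_bagged_trees(X, y, n_estimators=50, random_state=0):
    """Fit n_estimators fully grown trees, each on a bootstrap resample of (X, y)."""
    rng = np.random.RandomState(random_state)
    trees = []
    for _ in range(n_estimators):
        idx = rng.randint(0, len(X), len(X))  # sample rows with replacement
        trees.append(DecisionTreeClassifier().fit(X[idx], y[idx]))
    return trees

def predict_majority(trees, X):
    """Hard majority vote across trees; assumes labels are 0..K-1."""
    votes = np.stack([t.predict(X) for t in trees]).astype(int)  # (n_trees, n_samples)
    return np.array([np.bincount(col).argmax() for col in votes.T])

X, y = make_blobs(n_samples=300, centers=4, random_state=0, cluster_std=1.0)
trees = fit_bagged_trees(X, y)
yhat = predict_majority(trees, X)
print("training accuracy of the vote:", (yhat == y).mean())
```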
This type of bagging classification can be done manually using Scikit-Learn's BaggingClassifier meta-estimator, as
shown here:
这种类型的装袋分类可以通过Scikit-Learn的 BaggingClassifier 元评估器来⼿动实现,如下例:
In [8]: from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import BaggingClassifier
tree = DecisionTreeClassifier()
bag = BaggingClassifier(tree, n_estimators=100, max_samples=0.8,
random_state=1)
bag.fit(X, y)
visualize_classifier(bag, X, y)
In this example, we have randomized the data by fitting each estimator with a random subset of 80% of the training
points. In practice, decision trees are more effectively randomized by injecting some stochasticity in how the splits are
chosen: this way all the data contributes to the fit each time, but the results of the fit still have the desired randomness.
For example, when determining which feature to split on, the randomized tree might select from among the top several
features. You can read more technical details about these randomization strategies in the Scikit-Learn documentation and
references within.
在上例中,我们让每个评估器在训练数据中随机选取的80%数据点上进行拟合,以此引入随机性。在实践中,更有效的方法是在选择划分方式时为决策树注入随机性:这样每次拟合时所有的数据都会产生贡献,但是拟合的结果仍然具有期望的随机性。例如在决定使用哪个特征来划分数据时,随机决策树可以从排在前面的几个特征中进行选择。你可以在Scikit-Learn在线文档及其引用的文献中读到更多这些随机化策略的技术细节。
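scikit-learn's random forest injects this kind of split-level randomness through max_features: at each split only a random subset of the features is considered. The sketch below (settings chosen for illustration) turns off bootstrapping so that every tree sees all of the data, and shows that the trees still differ because the feature examined at each split is drawn at random. Note this is a related strategy rather than exactly "choosing among the top few features"; ExtraTreesClassifier goes further and also draws split thresholds at random.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.ensemble import RandomForestClassifier

X, y = make_blobs(n_samples=300, centers=4, random_state=0, cluster_std=1.0)

# bootstrap=False: every tree is trained on all of the data.
# max_features=1: at each split a single randomly drawn feature is considered,
# so the trees differ only through how their splits are chosen.
forest = RandomForestClassifier(n_estimators=50, bootstrap=False,
                                max_features=1, random_state=0).fit(X, y)

# Feature index used at the root node of each tree
root_features = [est.tree_.feature[0] for est in forest.estimators_]
print("root splits on feature 0 / feature 1:", np.bincount(root_features))
```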
In Scikit-Learn, such an optimized ensemble of randomized decision trees is implemented in the
RandomForestClassifier estimator, which takes care of all the randomization automatically. All you need to do is
select a number of estimators, and it will very quickly (in parallel, if desired) fit the ensemble of trees:
在Scikit-Learn中,上述的随机决策树的优化组合算法被实现在 RandomForestClassifier 评估器中,它能全⾃动地处理所有的随机情
况。你只需要设置评估器的个数,它能迅速的(根据需要进⾏并⾏计算)拟合整个森林:
In [9]: from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(n_estimators=100, random_state=0)
visualize_classifier(model, X, y);
We see that by averaging over 100 randomly perturbed models, we end up with an overall model that is much closer to
our intuition about how the parameter space should be split.
上⾯例⼦可以看到,通过在100个随机选择的模型上进⾏平均,我们能够得到⼀个更加符合我们对数据集分布的直觉模型。
Random Forest Regression
随机森林回归
In the previous section we considered random forests within the context of classification. Random forests can also be
made to work in the case of regression (that is, continuous rather than categorical variables). The estimator to use for this
is the RandomForestRegressor , and the syntax is very similar to what we saw earlier.
在前面内容中我们介绍了随机森林在分类场景下的应用。随机森林也能在回归场景中使用(即预测连续变量而不是离散的类别)。实现这个场景的评估器是 RandomForestRegressor ,它的语法和前面看到的分类语法很相似。
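For regression, "voting" becomes averaging: a random forest regressor predicts the mean of its trees' predictions. This minimal sketch checks that on data shaped like the oscillating example in this section (the sample size and number of trees are assumptions).

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.RandomState(42)
x = 10 * rng.rand(200)
y = np.sin(5 * x) + np.sin(0.5 * x) + 0.3 * rng.randn(200)

forest = RandomForestRegressor(n_estimators=50, random_state=0).fit(x[:, None], y)

xfit = np.linspace(0, 10, 5)[:, None]
# Stack each tree's predictions and average them by hand
per_tree = np.stack([tree.predict(xfit) for tree in forest.estimators_])
manual = per_tree.mean(axis=0)
print(np.allclose(manual, forest.predict(xfit)))
```

Because each tree is a piecewise-constant fit, the average of many slightly different trees gives the jagged but flexible curve seen in the figure.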
Consider the following data, drawn from the combination of a fast and slow oscillation:
考虑下⾯由⼀个快速震荡和缓慢震荡组合得到的数据集:
In [10]: rng = np.random.RandomState(42)
x = 10 * rng.rand(200)

def model(x, sigma=0.3):
    fast_oscillation = np.sin(5 * x)
    slow_oscillation = np.sin(0.5 * x)
    noise = sigma * rng.randn(len(x))

    return slow_oscillation + fast_oscillation + noise

y = model(x)
plt.errorbar(x, y, 0.3, fmt='o');
Using the random forest regressor, we can find the best fit curve as follows:
使⽤随机森林回归,我们能获得下⾯的最佳拟合曲线:
In [11]: from sklearn.ensemble import RandomForestRegressor
forest = RandomForestRegressor(200)
forest.fit(x[:, None], y)
xfit = np.linspace(0, 10, 1000)
yfit = forest.predict(xfit[:, None])
ytrue = model(xfit, sigma=0)
plt.errorbar(x, y, 0.3, fmt='o', alpha=0.5)
plt.plot(xfit, yfit, '-r');
plt.plot(xfit, ytrue, '-k', alpha=0.5);
Here the true model is shown in the smooth gray curve, while the random forest model is shown by the jagged red curve.
As you can see, the non-parametric random forest model is flexible enough to fit the multi-period data, without us needing
to specify a multi-period model!
上图中真实模型使用灰色光滑的曲线展示,而随机森林模型使用红色锯齿曲线展示。可以看到非参数的随机森林模型足够灵活,能够拟合多周期数据,而不需要我们指定任何多周期模型。
Example: Random Forest for Classifying Digits
例⼦:使⽤随机森林分类⼿写数字
Earlier we took a quick look at the hand-written digits data (see Introducing Scikit-Learn). Let's use that again here to see
how the random forest classifier can be used in this context.
前⾯我们快速的介绍了⼀下⼿写数字数据(参⻅Scikit-Learn简介)。下⾯我们来看看随机森林分类器在这个场景下的应⽤。
In [12]: from sklearn.datasets import load_digits
digits = load_digits()
digits.keys()
Out[12]: dict_keys(['data', 'target', 'target_names', 'images', 'DESCR'])
To remind us what we're looking at, we'll visualize the first few data points:
展⽰前⾯若⼲数据点,⽅便我们理解⼤概的数据集情况:
In [13]: # set up the figure
fig = plt.figure(figsize=(6, 6))  # figure size in inches
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)

# plot the digits: each image is 8x8 pixels
for i in range(64):
    ax = fig.add_subplot(8, 8, i + 1, xticks=[], yticks=[])
    ax.imshow(digits.images[i], cmap=plt.cm.binary, interpolation='nearest')

    # label the image with the target value
    ax.text(0, 7, str(digits.target[i]))
We can quickly classify the digits using a random forest as follows:
然后我们使⽤随机森林来分类这些数字:
In [14]: from sklearn.model_selection import train_test_split
Xtrain, Xtest, ytrain, ytest = train_test_split(digits.data, digits.target,
random_state=0)
model = RandomForestClassifier(n_estimators=1000)
model.fit(Xtrain, ytrain)
ypred = model.predict(Xtest)
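We can take a look at the classification report for this classifier. Note that classification_report expects the true labels first; because ypred is passed first below, the precision and recall columns are effectively swapped and support counts predicted labels: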
然后看⼀下这个分类器的分类性能报告:
In [15]: from sklearn import metrics
print(metrics.classification_report(ypred, ytest))
              precision    recall  f1-score   support

           0       1.00      0.97      0.99        38
           1       0.98      0.95      0.97        44
           2       0.95      1.00      0.98        42
           3       0.98      0.98      0.98        45
           4       0.97      1.00      0.99        37
           5       0.98      0.96      0.97        49
           6       1.00      1.00      1.00        52
           7       1.00      0.98      0.99        49
           8       0.96      0.98      0.97        47
           9       0.98      0.98      0.98        47

    accuracy                           0.98       450
   macro avg       0.98      0.98      0.98       450
weighted avg       0.98      0.98      0.98       450
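Argument order matters for per-class metrics: scikit-learn's metric functions take (y_true, y_pred), and swapping the two arguments exchanges precision and recall. This small sketch uses made-up toy labels (an assumption) to confirm that identity.

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score

# Toy labels, assumed for illustration; every class appears in both arrays
y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 0, 2, 2, 1, 2])

prec = precision_score(y_true, y_pred, average=None)
rec = recall_score(y_true, y_pred, average=None)

# Same calls with the arguments reversed
prec_swapped = precision_score(y_pred, y_true, average=None)
rec_swapped = recall_score(y_pred, y_true, average=None)
print(np.allclose(prec, rec_swapped), np.allclose(rec, prec_swapped))
```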
And for good measure, plot the confusion matrix:
为了更清晰,还可以绘制混淆矩阵:
In [16]: from sklearn.metrics import confusion_matrix
mat = confusion_matrix(ytest, ypred)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False)
plt.xlabel('true label')
plt.ylabel('predicted label');
We find that a simple, untuned random forest results in a very accurate classification of the digits data.
从这个例⼦我们发现,⼀个简单未经过优化的随机森林算法,就能在⼿写数字分类上得到⾮常精确的分类结果。
Summary of Random Forests
随机森林总结
This section contained a brief introduction to the concept of ensemble estimators, and in particular the random forest – an
ensemble of randomized decision trees. Random forests are a powerful method with several advantages:
Both training and prediction are very fast, because of the simplicity of the underlying decision trees. In addition, both
tasks can be straightforwardly parallelized, because the individual trees are entirely independent entities.
The multiple trees allow for a probabilistic classification: a majority vote among estimators gives an estimate of the
probability (accessed in Scikit-Learn with the predict_proba() method).
The nonparametric model is extremely flexible, and can thus perform well on tasks that are under-fit by other
estimators.
本节中介绍了集成评估器的概念,特别是随机森林,一种由随机化决策树组成的集成算法。随机森林由于下述优点使其成为很强大的一个方法:
训练和预测都非常快,因为其基础决策树计算非常简单。并且这两个任务都能直接地并行化,因为每一棵树都是完全独立的个体。
多棵决策树使得概率分类成为可能:评估器之间的多数投票给出了概率的估计值(使用Scikit-Learn的 predict_proba() 方法获取)。
非参数的模型非常灵活,可以在其他评估器欠拟合的任务上表现良好。
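A detail worth knowing about the probabilistic output: scikit-learn's RandomForestClassifier.predict_proba averages the per-tree class probabilities (soft voting) rather than literally counting votes. For fully grown trees, whose leaves are usually pure, the two are typically very close. The sketch below (toy query points assumed) computes both for comparison.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.ensemble import RandomForestClassifier

X, y = make_blobs(n_samples=300, centers=4, random_state=0, cluster_std=1.0)
forest = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)

Xnew = np.array([[0.0, 2.0], [2.0, 0.5]])   # arbitrary query points
proba = forest.predict_proba(Xnew)

# Soft voting: mean of each tree's class-probability estimates
soft = np.mean([t.predict_proba(Xnew) for t in forest.estimators_], axis=0)

# Hard voting: fraction of trees predicting each class (labels are 0..3 here)
votes = np.stack([t.predict(Xnew) for t in forest.estimators_]).astype(int)
hard = np.stack([np.bincount(v, minlength=4) for v in votes.T]) / len(forest.estimators_)

print("matches soft average:", np.allclose(proba, soft))
print("matches vote fractions:", np.allclose(proba, hard))
```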
A primary disadvantage of random forests is that the results are not easily interpretable: that is, if you would like to draw
conclusions about the meaning of the classification model, random forests may not be the best choice.
随机森林最主要的缺点在于结果不容易解释:也即是说,如果你试图从分类模型中提取深层次的含义的话,随机森林可能不是最好的选
择。
In Depth: Principal Component Analysis
深⼊:主成分分析
Up until now, we have been looking in depth at supervised learning estimators: those estimators that predict labels based
on labeled training data. Here we begin looking at several unsupervised estimators, which can highlight interesting
aspects of the data without reference to any known labels.
⽬前为⽌我们深⼊了解了⼀些有监督学习评估器:这些评估器构建在标记的训练数据的基础之上。下⾯我们看⼀些⽆监督学习评估器,它
们能在没有已知标签的情况下聚焦在数据的有意义的特征上。
In this section, we explore what is perhaps one of the most broadly used of unsupervised algorithms, principal component
analysis (PCA). PCA is fundamentally a dimensionality reduction algorithm, but it can also be useful as a tool for
visualization, for noise filtering, for feature extraction and engineering, and much more. After a brief conceptual
discussion of the PCA algorithm, we will see a couple of examples of these further applications.
本节中,我们将讨论⼀个可能使⽤最⼴泛的⽆监督学习算法,主成分分析(PCA)。PCA本质上是⼀个降维算法,但是它也可以作为可视
化、过滤噪⾳、特征提取和特征⼯程等⽅⾯的有⽤⼯具。在对PCA算法概念进⾏简要介绍之后,我们会看到其应⽤的⼀些场景。
We begin with the standard imports:
⾸先导⼊需要的包:
In [1]: %matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
Introducing Principal Component Analysis
主成分分析简介
Principal component analysis is a fast and flexible unsupervised method for dimensionality reduction in data, which we
saw briefly in Introducing Scikit-Learn. Its behavior is easiest to visualize by looking at a two-dimensional dataset.
Consider the following 200 points:
主成分分析是⼀个快速⽽灵活的⽆监督学习算法,主要⽤于降低数据的维度,我们在Scikit-Learn简介中已经简短地介绍过它。在下⾯的⼆
维数据集中,能更⽅便我们理解PCA⽅法。例如下⾯200个数据点:
In [2]: rng = np.random.RandomState(1)
X = np.dot(rng.rand(2, 2), rng.randn(2, 200)).T
plt.scatter(X[:, 0], X[:, 1])
plt.axis('equal');
By eye, it is clear that there is a nearly linear relationship between the x and y variables. This is reminiscent of the linear
regression data we explored in In Depth: Linear Regression, but the problem setting here is slightly different: rather than
attempting to predict the y values from the x values, the unsupervised learning problem attempts to learn about the
relationship between the x and y values.
⾁眼就能清晰观察到x和y变量之间的线性关系。这种线性关系让我们想起在深⼊:线性回归中介绍的内容,但是本节问题有⼀点不同:我
们不是希望训练⼀个可以通过x值预测y值的模型,⽽是希望模型能够学习到x和y值之间的关联。
In principal component analysis, this relationship is quantified by finding a list of the principal axes in the data, and using
those axes to describe the dataset. Using Scikit-Learn's PCA estimator, we can compute this as follows:
在主成分分析中,这种关联关系被量化成在数据中找到⼀个主要特征轴的列表,然后使⽤这些轴来描绘数据集。使⽤Scikit-Learn的 PCA
评估器,我们可以进⾏如下计算:
In [3]: from sklearn.decomposition import PCA
pca = PCA(n_components=2)
pca.fit(X)
Out[3]: PCA(copy=True, iterated_power='auto', n_components=2, random_state=None,
svd_solver='auto', tol=0.0, whiten=False)
The fit learns some quantities from the data, most importantly the "components" and "explained variance":
拟合过程从数据中学习到了⼀些定量指标,最重要的是其中的“成分”和“可解释⽅差”:
In [4]: print(pca.components_)
[[-0.94446029 -0.32862557]
[-0.32862557 0.94446029]]
In [5]: print(pca.explained_variance_)
[0.7625315 0.0184779]
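These two quantities can be reproduced without scikit-learn: the components are the eigenvectors of the sample covariance matrix, and the explained variances are its eigenvalues. This sketch regenerates the same 200 points and compares the results; eigenvectors are only defined up to sign, so the comparison uses absolute values.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.RandomState(1)
X = np.dot(rng.rand(2, 2), rng.randn(2, 200)).T

# Eigen-decomposition of the sample covariance matrix (divides by n - 1)
Xc = X - X.mean(axis=0)
cov = Xc.T @ Xc / (len(X) - 1)
eigvals, eigvecs = np.linalg.eigh(cov)      # returned in ascending order
order = np.argsort(eigvals)[::-1]           # largest variance first
eigvals, eigvecs = eigvals[order], eigvecs[:, order]

pca = PCA(n_components=2).fit(X)
print(np.allclose(eigvals, pca.explained_variance_))
# Each principal axis agrees up to an arbitrary sign flip
print(np.allclose(np.abs(eigvecs.T), np.abs(pca.components_)))
```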
To see what these numbers mean, let's visualize them as vectors over the input data, using the "components" to define
the direction of the vector, and the "explained variance" to define the squared-length of the vector:
要展示这些数值代表的含义,可以把它们可视化成输入数据上的矢量,使用“成分”来确定矢量的方向,使用“可解释方差”来确定矢量长度的平方:
In [6]: def draw_vector(v0, v1, ax=None):
    ax = ax or plt.gca()
    arrowprops=dict(arrowstyle='->',
                    linewidth=2,
                    shrinkA=0, shrinkB=0)
    ax.annotate('', v1, v0, arrowprops=arrowprops)

# plot data
plt.scatter(X[:, 0], X[:, 1], alpha=0.2)
for length, vector in zip(pca.explained_variance_, pca.components_):
    v = vector * 3 * np.sqrt(length)
    draw_vector(pca.mean_, pca.mean_ + v)
plt.axis('equal');
These vectors represent the principal axes of the data, and the length of the vector is an indication of how "important"
that axis is in describing the distribution of the data—more precisely, it is a measure of the variance of the data when
projected onto that axis. The projection of each data point onto the principal axes are the "principal components" of the
data.
这些⽮量代表着数据的主要特征轴,⽽⽮量的⻓度代表着这个轴对于数据的分布起到了多重要的作⽤,更精确来说,这是数据被投射到这
个轴上时⽅差的度量。将每个数据点投射到主要特征轴上被称为数据的“主要成分”。
If we plot these principal components beside the original data, we see the plots shown here:
如果我们将主要成分画在原始数据旁边,会得到下图:
附录中⽣成图像的代码
This transformation from data axes to principal axes is an affine transformation, which basically means it is composed of
a translation, rotation, and uniform scaling.
这种从数据轴到主轴的变换是一种仿射变换,基本上意味着它由平移、旋转和均匀缩放组成。
While this algorithm to find principal components may seem like just a mathematical curiosity, it turns out to have very far-reaching applications in the world of machine learning and data exploration.
虽然这个查找主成分的算法看起来就像数学理论⽽已,但是实际上它在机器学习和数据挖掘领域有着⾮常⼴泛的应⽤。
PCA as dimensionality reduction
使⽤PCA降维
Using PCA for dimensionality reduction involves zeroing out one or more of the smallest principal components, resulting
in a lower-dimensional projection of the data that preserves the maximal data variance.
使⽤PCA降维主要包括将⼀个或多个次要成分从数据中移除,从⽽获得数据的⼀个低维度的映射并保留最⼤化的数据差异。
Here is an example of using PCA as a dimensionality reduction transform:
下⾯是⼀个使⽤PCA进⾏降维转换的例⼦:
In [7]: pca = PCA(n_components=1)
pca.fit(X)
X_pca = pca.transform(X)
print("original shape:   ", X.shape)
print("transformed shape:", X_pca.shape)
original shape:    (200, 2)
transformed shape: (200, 1)
The transformed data has been reduced to a single dimension. To understand the effect of this dimensionality reduction,
we can perform the inverse transform of this reduced data and plot it along with the original data:
转换后的数据被降到了一维。要理解降维的效果,我们可以对降维后的数据进行逆转换,并将其与原始数据一起绘制在图中:
In [8]: X_new = pca.inverse_transform(X_pca)
plt.scatter(X[:, 0], X[:, 1], alpha=0.2)
plt.scatter(X_new[:, 0], X_new[:, 1], alpha=0.8)
plt.axis('equal');
The light points are the original data, while the dark points are the projected version. This makes clear what a PCA
dimensionality reduction means: the information along the least important principal axis or axes is removed, leaving only
the component(s) of the data with the highest variance. The fraction of variance that is cut out (proportional to the spread
of points about the line formed in this figure) is roughly a measure of how much "information" is discarded in this
reduction of dimensionality.
浅色的点代表原始数据,而深色的点是投影后的数据。上图清晰地表示了PCA降维的含义:沿着最不重要的主轴(一个或多个)上的信息都被移除了,只留下了方差最大的数据成分。被移除的方差比例(与图中数据点偏离那条直线的分散程度成正比)粗略地衡量了降维操作中有多少“信息”被丢弃了。
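The link between discarded variance and lost information can be made exact for data the PCA was fitted on: the mean squared distance from each point to its projection (dividing by $n - 1$) equals the variance of the dropped component. This sketch regenerates the same 200 points to check it.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.RandomState(1)
X = np.dot(rng.rand(2, 2), rng.randn(2, 200)).T

full = PCA(n_components=2).fit(X)    # keeps both components, for reference
pca = PCA(n_components=1).fit(X)
X_new = pca.inverse_transform(pca.transform(X))

# Squared distance from each point to its projection, summed and divided by n - 1
residual_variance = np.sum((X - X_new) ** 2) / (len(X) - 1)
print(residual_variance, full.explained_variance_[1])
print("fraction of variance discarded:", 1 - pca.explained_variance_ratio_.sum())
```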
This reduced-dimension dataset is in some senses "good enough" to encode the most important relationships between
the points: despite reducing the dimension of the data by 50%, the overall relationship between the data points is mostly
preserved.
降维后的数据集在某种意义上已经“足够好”,能够编码数据点之间最重要的关系:虽然将数据的维度减少了一半,但是数据点之间的整体联系大部分被保留了下来。
PCA for visualization: Hand-written digits
使⽤PCA进⾏可视化:⼿写数字
The usefulness of the dimensionality reduction may not be entirely apparent in only two dimensions, but becomes much
more clear when looking at high-dimensional data. To see this, let's take a quick look at the application of PCA to the
digits data we saw in In-Depth: Decision Trees and Random Forests.
当数据仅有两个维度时,降维的效果并不明显,但是在⾼维度数据的情况下,这个操作的威⼒就体现出来了。为了展⽰这点,让我们将
PCA应⽤在⼿写数字数据上,我们在深⼊:决策树和随机森林中已经看到过它的应⽤。
We start by loading the data:
⾸先载⼊数据:
In [9]: from sklearn.datasets import load_digits
digits = load_digits()
digits.data.shape
Out[9]: (1797, 64)
Recall that the data consists of 8×8 pixel images, meaning that they are 64-dimensional. To gain some intuition into the
relationships between these points, we can use PCA to project them to a more manageable number of dimensions, say
two:
我们前⾯知道数据是由8x8像素的图像构成,这表⽰数据共有64个维度。要获得这些数据点之间的内在联系,我们可以使⽤PCA将它们投
射到更容易管理的维度数量上,例如2:
In [10]: pca = PCA(2)  # project from 64 to 2 dimensions
projected = pca.fit_transform(digits.data)
print(digits.data.shape)
print(projected.shape)
(1797, 64)
(1797, 2)
We can now plot the first two principal components of each point to learn about the data:
然后我们就可以将数据的两个主成分绘制在下图中:
In [11]: plt.scatter(projected[:, 0], projected[:, 1],
c=digits.target, edgecolor='none', alpha=0.5,
cmap=plt.get_cmap('viridis', 10))
plt.xlabel('component 1')
plt.ylabel('component 2')
plt.colorbar();
Recall what these components mean: the full data is a 64-dimensional point cloud, and these points are the projection of
each data point along the directions with the largest variance. Essentially, we have found the optimal stretch and rotation
in 64-dimensional space that allows us to see the layout of the digits in two dimensions, and have done this in an
unsupervised manner—that is, without reference to the labels.
回忆⼀下这些成分的含义:完整的数据是64维的数据点组成的云,上图中的点是每个数据点投射到最⼤差异⽅向上的投射点。或者说基本
上,我们找到了通过最优的拉伸和旋转将64维数据展⽰在2维上的⽅式,并且采取的是⼀种⽆监督的⼿段,也就是没有任何的标签参考。
What do the components mean?
成分的内在涵义
We can go a bit further here, and begin to ask what the reduced dimensions mean. This meaning can be understood in
terms of combinations of basis vectors. For example, each image in the training set is defined by a collection of 64 pixel
values, which we will call the vector $x$:
$$x = [x_1, x_2, x_3 \cdots x_{64}]$$
下面继续深入探讨一下,这个降维之后的结果究竟有什么意义。我们可以通过基向量的组合来理解它。例如训练集中的每张图像都是64个像素值的集合,我们把它称为向量 $x$:
$$x = [x_1, x_2, x_3 \cdots x_{64}]$$
One way we can think about this is in terms of a pixel basis. That is, to construct the image, we multiply each element of
the vector by the pixel it describes, and then add the results together to build the image:
$$\mathrm{image}(x) = x_1 \cdot (\text{pixel 1}) + x_2 \cdot (\text{pixel 2}) + x_3 \cdot (\text{pixel 3}) \cdots x_{64} \cdot (\text{pixel 64})$$
一种理解方式是以像素为基:也就是说,要构成一张图像,我们将上面向量中的每个元素乘上它所代表的像素点,然后将这些结果加起来就能得到图像:
$$\mathrm{image}(x) = x_1 \cdot (\text{pixel 1}) + x_2 \cdot (\text{pixel 2}) + x_3 \cdot (\text{pixel 3}) \cdots x_{64} \cdot (\text{pixel 64})$$
One way we might imagine reducing the dimension of this data is to zero out all but a few of these basis vectors. For
example, if we use only the first eight pixels, we get an eight-dimensional projection of the data, but it is not very reflective
of the whole image: we've thrown out nearly 90% of the pixels!
可以认为降维就是将除了需要保留的基础⽮量外的部分全部移除。例如如果我们仅使⽤前⾯8个像素,我们就能得到8维的数据投射,但是
结果并不能完整展⽰原图:因为我们丢弃了接近90%的像素。
附录中⽣成图像的代码
The upper row of panels shows the individual pixels, and the lower row shows the cumulative contribution of these pixels
to the construction of the image. Using only eight of the pixel-basis components, we can only construct a small portion of
the 64-pixel image. Were we to continue this sequence and use all 64 pixels, we would recover the original image.
上图中第⼀⾏展⽰的是单独的像素点,第⼆⾏展⽰的是对应的像素点加⼊后对累计求和结果产⽣的影响。仅使⽤8个像素点基础成分时,我
们只能构建64像素图像的⼀⼩部分。继续这个过程将所有64个像素相加后,我们就能恢复原始图像。
But the pixel-wise representation is not the only choice of basis. We can also use other basis functions, which each
contain some pre-defined contribution from each pixel, and write something like
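$$\mathrm{image}(x) = \mathrm{mean} + x_1 \cdot (\text{basis 1}) + x_2 \cdot (\text{basis 2}) + x_3 \cdot (\text{basis 3}) \cdots$$
但是逐个像素点的表示方式不是唯一的基选择。我们也可以使用其他的基函数,每个基函数都包含了每个像素的某种预设贡献值,比方说写成
$$\mathrm{image}(x) = \mathrm{mean} + x_1 \cdot (\text{basis 1}) + x_2 \cdot (\text{basis 2}) + x_3 \cdot (\text{basis 3}) \cdots$$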
PCA can be thought of as a process of choosing optimal basis functions, such that adding together just the first few of
them is enough to suitably reconstruct the bulk of the elements in the dataset. The principal components, which act as
the low-dimensional representation of our data, are simply the coefficients that multiply each of the elements in this
series. This figure shows a similar depiction of reconstructing this digit using the mean plus the first eight PCA basis
functions:
PCA可以看作是选择最优基函数的过程,使得只需要将其中头几项相加就足以较好地重建数据集中的大部分元素。主成分作为数据的低维度表示,其实就是这个级数中各项的系数。下图展示了使用平均值加上头8个PCA基函数重建这个数字的情况:
附录中⽣成图像的代码
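The expansion $\mathrm{image}(x) = \mathrm{mean} + x_1 \cdot (\text{basis 1}) + \cdots$ maps directly onto attributes of a fitted scikit-learn PCA: mean_ is the mean image, the rows of components_ are the basis functions, and transform returns the coefficients. This sketch rebuilds the first digit from eight components by hand and checks it against inverse_transform (which, without whitening, computes the same sum).

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA

digits = load_digits()
pca = PCA(n_components=8).fit(digits.data)

x = digits.data[0]
coeffs = pca.transform(x[None, :])[0]     # the 8 principal components of this image

# image(x) ≈ mean + c1 * (basis 1) + ... + c8 * (basis 8)
approx = pca.mean_ + coeffs @ pca.components_
print(np.allclose(approx, pca.inverse_transform(coeffs[None, :])[0]))
print(approx.reshape(8, 8).round(1))
```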
Unlike the pixel basis, the PCA basis allows us to recover the salient features of the input image with just a mean plus
eight components! The amount of each pixel in each component is the corollary of the orientation of the vector in our two-dimensional example. This is the sense in which PCA provides a low-dimensional representation of the data: it discovers
a set of basis functions that are more efficient than the native pixel-basis of the input data.
不同于像素基,PCA基允许我们仅使用平均值加上8个成分就还原出输入图像的显著特征。每个成分中各个像素所占的分量,对应于我们二维例子中矢量的方向。这就是PCA提供数据低维度表示的含义:它找到了一组比输入数据原始的像素基更高效的基函数。
Choosing the number of components
选择成分的数量
A vital part of using PCA in practice is the ability to estimate how many components are needed to describe the data.
This can be determined by looking at the cumulative explained variance ratio as a function of the number of components:
在实践中使用PCA的一个重要环节是估算需要多少个成分来描述数据。这可以通过查看累计可解释方差比例随成分数量变化的曲线来决定:
In [12]: pca = PCA().fit(digits.data)
plt.plot(np.cumsum(pca.explained_variance_ratio_))
plt.xlabel('number of components')
plt.ylabel('cumulative explained variance');
This curve quantifies how much of the total, 64-dimensional variance is contained within the first $N$ components. For
example, we see that with the digits the first 10 components contain approximately 75% of the variance, while you need
around 50 components to describe close to 100% of the variance.
这条曲线量化了64维的总方差中有多少包含在前 $N$ 个成分之中。例如我们看到,对于手写数字数据,前10个成分包含了大约75%的方差,而成分数量需要达到50个左右时,才能描述接近100%的方差。
Here we see that our two-dimensional projection loses a lot of information (as measured by the explained variance) and
that we'd need about 20 components to retain 90% of the variance. Looking at this plot for a high-dimensional dataset
can help you understand the level of redundancy present in multiple observations.
由上图我们也看到我们的⼆维投射损失了很多的信息(由可解释⽅差衡量)我们需要⼤概20个成分才能获得90%的可解释⽅差占⽐。将这
个图应⽤在⾼维度的数据集上时,能帮助你理解不同维度情况下数据的冗余度情况。
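Reading a threshold off this curve can be automated. The helper below is a hypothetical sketch that returns the smallest number of components whose cumulative explained variance exceeds a target; passing a float between 0 and 1 as n_components to scikit-learn's PCA applies the same kind of rule internally, which the last line compares against.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA

digits = load_digits()
cumulative = np.cumsum(PCA().fit(digits.data).explained_variance_ratio_)

def n_components_for(cumulative, target):
    """Smallest number of components whose cumulative explained variance exceeds target."""
    return int(np.searchsorted(cumulative, target, side='right') + 1)

k90 = n_components_for(cumulative, 0.90)
print(k90, cumulative[k90 - 1])
print(PCA(0.90).fit(digits.data).n_components_)
```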
PCA as Noise Filtering
使⽤PCA去噪
PCA can also be used as a filtering approach for noisy data. The idea is this: any components with variance much larger
than the effect of the noise should be relatively unaffected by the noise. So if you reconstruct the data using just the
largest subset of principal components, you should be preferentially keeping the signal and throwing out the noise.
PCA也可以用来过滤数据中的噪音。这其中的原理是:任何方差远大于噪音影响的成分,受到噪音的影响都相对较小。因此如果你只使用方差最大的那部分主成分来重建数据集的话,应该能优先保留信号并丢弃噪音。
Let's see how this looks with the digits data. First we will plot several of the input noise-free data:
让我们来看看它在⼿写数字数据集中的表现。⾸先绘制⽆噪⾳情况下的数字:
In [13]: def plot_digits(data):
    fig, axes = plt.subplots(4, 10, figsize=(10, 4),
                             subplot_kw={'xticks':[], 'yticks':[]},
                             gridspec_kw=dict(hspace=0.1, wspace=0.1))
    for i, ax in enumerate(axes.flat):
        ax.imshow(data[i].reshape(8, 8),
                  cmap='binary', interpolation='nearest',
                  clim=(0, 16))
plot_digits(digits.data)
Now let's add some random noise to create a noisy dataset, and re-plot it:
下⾯我们增加⼀些随机噪⾳,来创建⼀个有噪⾳的数据集,并绘制数字:
In [14]: np.random.seed(42)
noisy = np.random.normal(digits.data, 4)
plot_digits(noisy)
It's clear by eye that the images are noisy, and contain spurious pixels. Let's train a PCA on the noisy data, requesting
that the projection preserve 50% of the variance:
⾁眼可⻅图像多了⼀些随机噪⾳像素点。然后在带噪⾳的数据上使⽤PCA,保留50%的可解释⽅差:
In [15]: pca = PCA(0.50).fit(noisy)
pca.n_components_
Out[15]: 12
Here 50% of the variance amounts to 12 principal components. Now we compute these components, and then use the
inverse of the transform to reconstruct the filtered digits:
上⾯的输出告诉我们,只需要使⽤12个主成分就可以保留50%的可解释⽅差。然后我们进⾏仿射变换,先转换成主成分,再逆转换回来,
就可以重建过滤掉噪⾳后的数字:
In [16]: components = pca.transform(noisy)
filtered = pca.inverse_transform(components)
plot_digits(filtered)
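The visual improvement can also be measured when the clean signal is known. This sketch uses synthetic data (an assumption, not the digits): a signal that truly lives in 3 of 50 dimensions, plus unit-variance noise in every dimension. Keeping only the 3 leading components discards most of the noise, so the filtered data ends up much closer to the clean signal than the noisy input.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.RandomState(0)
# Synthetic low-rank signal: 500 samples in 50 dimensions built from 3 latent factors
latent = rng.randn(500, 3) * [5.0, 3.0, 2.0]
mixing = rng.randn(3, 50)
clean = latent @ mixing
noisy = clean + rng.randn(*clean.shape)   # unit-variance noise in every dimension

pca = PCA(n_components=3).fit(noisy)
filtered = pca.inverse_transform(pca.transform(noisy))

mse_noisy = np.mean((noisy - clean) ** 2)
mse_filtered = np.mean((filtered - clean) ** 2)
print(f"noisy MSE: {mse_noisy:.3f}, filtered MSE: {mse_filtered:.3f}")
```

Roughly, only the noise lying inside the kept 3-dimensional subspace survives, about 3/50 of it, which is why the error drops so sharply here; on real data the signal is not exactly low-rank, so the trade-off is less clean.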
This signal preserving/noise filtering property makes PCA a very useful feature selection routine—for example, rather
than training a classifier on very high-dimensional data, you might instead train the classifier on the lower-dimensional
representation, which will automatically serve to filter out random noise in the inputs.
这个保留信号或过滤噪⾳的特性使得PCA在特征选择过程中⾮常有⽤,例如,与其在⾮常⾼维度数据上训练⼀个分类器,你倒不如在数据
的低维度表⽰中去进⾏训练,这样做的好处是数据中的随机噪⾳已经被过滤掉了。
Example: Eigenfaces
例⼦:特征脸谱
Earlier we explored an example of using a PCA projection as a feature selector for facial recognition with a support vector
machine (see In-Depth: Support Vector Machines). Here we will take a look back and explore a bit more of what went into
that. Recall that we were using the Labeled Faces in the Wild dataset made available through Scikit-Learn:
前面章节我们探索了使用支持向量机(参见深入:支持向量机)来进行人脸识别,那里我们采用PCA投影作为特征选择器。下面我们来回顾一下,并且进行更深入的探索。回想一下,我们当时使用的是通过Scikit-Learn获取的Labeled Faces in the Wild数据集:
In [17]: from sklearn.datasets import fetch_lfw_people
faces = fetch_lfw_people(min_faces_per_person=60)
print(faces.target_names)
print(faces.images.shape)
['Ariel Sharon' 'Colin Powell' 'Donald Rumsfeld' 'George W Bush'
'Gerhard Schroeder' 'Hugo Chavez' 'Junichiro Koizumi' 'Tony Blair']
(1348, 62, 47)
Let's take a look at the principal axes that span this dataset. Because this is a large dataset, we will use
RandomizedPCA —it contains a randomized method to approximate the first N principal components much more
quickly than the standard PCA estimator, and thus is very useful for high-dimensional data (here, a dimensionality of
nearly 3,000). We will take a look at the first 150 components:
让我们首先查看一下这个数据集的主要成分。因为这是一个较大的数据集,我们会使用 RandomizedPCA ,这个方法会使用一个随机的方
法来估算数据集的前 N 个主成分,它比标准的 PCA 评估器要快得多,因此在高维度数据中很有用(本数据集中维度接近3000)。我们看
一下前150个主要成分:
Translator's note: RandomizedPCA is obsolete in newer versions of Scikit-Learn; PCA is used instead, and you only need to pass the argument svd_solver='randomized'. The code below has been adjusted accordingly.
In [18]: from sklearn.decomposition import PCA as RandomizedPCA
pca = RandomizedPCA(150, svd_solver='randomized')
pca.fit(faces.data)
Out[18]: PCA(copy=True, iterated_power='auto', n_components=150, random_state=None,
svd_solver='randomized', tol=0.0, whiten=False)
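The idea behind the randomized solver can be sketched in a few lines of NumPy: multiply the centered data by a small random matrix to capture its dominant column space, sharpen that subspace with a few power iterations, then take an exact SVD of the much smaller projected matrix. This is a minimal sketch of the randomized range-finder approach (after Halko, Martinsson, and Tropp), not Scikit-Learn's actual implementation; the function name and the n_oversamples / n_iter defaults are assumptions.

```python
import numpy as np


def randomized_pca(X, n_components, n_oversamples=10, n_iter=4, seed=0):
    """Approximate the leading principal axes (rows, like PCA.components_)
    and singular values of X with a randomized range finder."""
    rng = np.random.default_rng(seed)
    Xc = X - X.mean(axis=0)
    # 1. Sample the range of Xc with a random Gaussian test matrix
    omega = rng.normal(size=(Xc.shape[1], n_components + n_oversamples))
    Q, _ = np.linalg.qr(Xc @ omega)
    # 2. Power iterations emphasise the largest singular directions
    for _ in range(n_iter):
        Q, _ = np.linalg.qr(Xc @ (Xc.T @ Q))
    # 3. Exact SVD of the small (k + p) x n_features matrix
    _, S, Vt = np.linalg.svd(Q.T @ Xc, full_matrices=False)
    return Vt[:n_components], S[:n_components]


# Compare with an exact SVD on data whose variance decays along the feature axes
rng = np.random.default_rng(1)
X = rng.normal(size=(500, 50)) * 0.7 ** np.arange(50)
approx_axes, approx_S = randomized_pca(X, n_components=5)
_, exact_S, exact_Vt = np.linalg.svd(X - X.mean(axis=0), full_matrices=False)
# Principal axes agree up to sign, so |cosine| between matching axes is ~1
cosines = np.abs(np.sum(approx_axes * exact_Vt[:5], axis=1))
print(np.round(cosines, 6))
```

The expensive work is a handful of products with a thin random matrix rather than a full decomposition, which is why this kind of solver pays off when only a few components of very wide data, such as the roughly 3,000-pixel faces, are needed.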
In this case, it can be interesting to visualize the images associated with the first several principal components (these
components are technically known as "eigenvectors," so these types of images are often called "eigenfaces"). As you can
see in this figure, they are as creepy as they sound:
将前几个主成分对应的图像可视化出来,会得到很有趣的结果(这些成分在技术上被称为“特征向量”,因此这些图像经常被称
为“特征脸谱”)。正如你下面看到的,这些图像看起来很怪异:
In [19]: fig, axes = plt.subplots(3, 8, figsize=(9, 4),
                         subplot_kw={'xticks':[], 'yticks':[]},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i, ax in enumerate(axes.flat):
    ax.imshow(pca.components_[i].reshape(62, 47), cmap='bone')
The results are very interesting, and give us insight into how the images vary: for example, the first few eigenfaces (from
the top left) seem to be associated with the angle of lighting on the face, and later principal vectors seem to be picking
out certain features, such as eyes, noses, and lips. Let's take a look at the cumulative variance of these components to
see how much of the data information the projection is preserving:
这些结果很有意思,它们为我们提供了图像内在的差异性:例如,左上⻆开始的⼏张特征脸谱看起来像是被不同⻆度的光照之下的结果,
⽽后⾯的主要向量则似乎选择了不同的脸部特征,例如眼睛、⿐⼦和嘴唇等。然后我们看⼀下这些成分的可解释⽅差的⽐例曲线,得到脸
部信息保留⽐例的⼤概印象:
In [20]: plt.plot(np.cumsum(pca.explained_variance_ratio_))
plt.xlabel('number of components')
plt.ylabel('cumulative explained variance');
We see that these 150 components account for just over 90% of the variance. That would lead us to believe that using
these 150 components, we would recover most of the essential characteristics of the data. To make this more concrete,
we can compare the input images with the images reconstructed from these 150 components:
结果表明这150个成分保留了超过90%的方差。这让我们相信,使用这150个主成分就可以还原数据的大部分关键特征。更具体地说,我
们可以将原始图像和通过150个成分还原的图像进行比较:
In [21]: # Compute the components and projected faces
pca = RandomizedPCA(150, svd_solver='randomized').fit(faces.data)
components = pca.transform(faces.data)
projected = pca.inverse_transform(components)
In [22]: # Plot the results
fig, ax = plt.subplots(2, 10, figsize=(10, 2.5),
                       subplot_kw={'xticks':[], 'yticks':[]},
                       gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i in range(10):
    ax[0, i].imshow(faces.data[i].reshape(62, 47), cmap='binary_r')
    ax[1, i].imshow(projected[i].reshape(62, 47), cmap='binary_r')

ax[0, 0].set_ylabel('full-dim\ninput')
ax[1, 0].set_ylabel('150-dim\nreconstruction');
The top row here shows the input images, while the bottom row shows the reconstruction of the images from just 150 of
the ~3,000 initial features. This visualization makes clear why the PCA feature selection used in In-Depth: Support Vector
Machines was so successful: although it reduces the dimensionality of the data by nearly a factor of 20, the projected
images contain enough information that we might, by eye, recognize the individuals in the image. What this means is that
our classification algorithm needs to be trained on 150-dimensional data rather than 3,000-dimensional data, which,
depending on the particular algorithm we choose, can lead to a much more efficient classification.
第一行是原始图像,第二行是仅用150个主成分(原始特征约3000个)还原得到的图像。上面的结果很清晰地解释了为什么主成分
分析选择的特征在深入:支持向量机当中应用得如此成功:虽然它将数据维度降低到初始数量的5%左右,但是它投射得到的图像包含了
足够的信息,我们可以通过肉眼就分辨出每张图的个体。这意味着我们的分类算法只需要在150维的数据上进行训练,而不是在3000维数
据上进行;取决于所选的具体算法,这可以让分类变得高效得多。
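Reading the number of components off a cumulative explained-variance curve, as done above for the 90% level, can also be done programmatically. A minimal sketch follows; the helper name n_components_for and the example ratios are hypothetical. Passing a fraction to Scikit-Learn's PCA, as in PCA(0.50) earlier, performs a similar selection internally.

```python
import numpy as np


def n_components_for(explained_variance_ratio, target):
    """Smallest number of leading components whose cumulative explained
    variance ratio is at least `target` (e.g. 0.9 for 90%)."""
    cumulative = np.cumsum(explained_variance_ratio)
    if cumulative[-1] < target:
        raise ValueError("the supplied components do not reach the target")
    # argmax returns the first index where the condition holds
    return int(np.argmax(cumulative >= target)) + 1


# Hypothetical ratios, sorted from largest to smallest as PCA reports them
ratios = np.array([0.40, 0.25, 0.15, 0.10, 0.05, 0.05])
k = n_components_for(ratios, 0.85)

# Dimensionality reduction for the faces example: 62 x 47 pixels down to 150 components
reduction = (62 * 47) / 150
print(k, round(reduction, 1))
```

The second number, about 19.4, is the "nearly a factor of 20" reduction described above.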
Principal Component Analysis Summary
主成分分析总结
In this section we have discussed the use of principal component analysis for dimensionality reduction, for visualization of
high-dimensional data, for noise filtering, and for feature selection within high-dimensional data. Because of the versatility
and interpretability of PCA, it has been shown to be effective in a wide variety of contexts and disciplines. Given any high-dimensional dataset, I tend to start with PCA in order to visualize the relationship between points (as we did with the
digits), to understand the main variance in the data (as we did with the eigenfaces), and to understand the intrinsic
dimensionality (by plotting the explained variance ratio). Certainly PCA is not useful for every high-dimensional dataset,
but it offers a straightforward and efficient path to gaining insight into high-dimensional data.
本节中我们讨论了使⽤主成分分析进⾏降维、对⾼维数据可视化、去噪和⾼维数据特征选择的⽅法。因为PCA算法的可解释性和灵活性,
它在⼤范围的情景和⽅法中展现了有效性。当⾯对⼀个⾼维数据集时,作者倾向于⾸先使⽤PCA来将数据点的关联关系可视化出来(正如
我们在⼿写数字中做的那样),然后试图找到数据中的最主要的可解释⽅差(正如我们在特征脸谱中做的那样),还有⽤来理解数据的固
有维度(通过绘制可解释方差比率)。当然,PCA并不是对每个高维数据集都有用,但是它提供了一个直接和有效的探视高维度数据内
部特征的途径。
PCA's main weakness is that it tends to be highly affected by outliers in the data. For this reason, many robust variants of
PCA have been developed, many of which act to iteratively discard data points that are poorly described by the initial
components. Scikit-Learn contains a couple of interesting variants on PCA, including RandomizedPCA and SparsePCA ,
both also in the sklearn.decomposition submodule. RandomizedPCA , which we saw earlier (current Scikit-Learn releases provide it as PCA with svd_solver='randomized'), uses a non-deterministic method to quickly approximate the first few principal components in very high-dimensional data, while
SparsePCA introduces a regularization term (see In Depth: Linear Regression) that serves to enforce sparsity of the
components.
PCA的主要缺点在于它容易受到数据中离群值的影响。正因为此,很多更加健壮的PCA变种被开发出来,其中很多都致力于迭代丢弃那些
无法被初始成分很好描述的数据点。Scikit-Learn也包含了一些PCA的有趣的变种,包括 RandomizedPCA 和 SparsePCA ,它们也位于
sklearn.decomposition 包中。 RandomizedPCA 我们上面用来在高维数据中快速地近似找到主要成分,而 SparsePCA 通过引
入一个正则项(参见深入:线性回归)来强制让成分变得稀疏。
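The outlier sensitivity mentioned above is easy to reproduce: PCA maximizes variance, and variance is a sum of squared deviations, so a single extreme point can dominate it. A minimal NumPy sketch with hypothetical synthetic data (the helper first_axis is illustrative only):

```python
import numpy as np


def first_axis(X):
    """Leading principal axis of X (unit vector), computed from the SVD."""
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Vt[0]


rng = np.random.default_rng(0)
# 200 points stretched along x (std 3) and narrow along y (std 0.5)
X = rng.normal(size=(200, 2)) * np.array([3.0, 0.5])
clean_axis = first_axis(X)

# Add one extreme point far out along y and refit
X_outlier = np.vstack([X, [0.0, 100.0]])
outlier_axis = first_axis(X_outlier)

print(np.round(np.abs(clean_axis), 2), np.round(np.abs(outlier_axis), 2))
```

With the outlier added, the leading axis swings from roughly the x-direction to roughly the y-direction, even though 200 of the 201 points are unchanged.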
In the following sections, we will look at other unsupervised learning methods that build on some of the ideas of PCA.
在下⾯章节中,我们还会看到其他的⼀些构建在PCA理论基础上的⽆监督学习的⽅法。
In-Depth: Manifold Learning
深⼊:流形学习
We have seen how principal component analysis (PCA) can be used in the dimensionality reduction task—reducing the
number of features of a dataset while maintaining the essential relationships between the points. While PCA is flexible,
fast, and easily interpretable, it does not perform so well when there are nonlinear relationships within the data; we will
see some examples of these below.
上⼀节已经介绍了主成分分析(PCA)⽤来进⾏降维,即减少数据集特征的数量却保留了数据点之间的关键关联。虽然PCA很灵活、快速
和容易解释,它在数据之间存在⾮线性关系的时候表现不是特别好;我们会在下⾯的⼀些例⼦中看到。
To address this deficiency, we can turn to a class of methods known as manifold learning—a class of unsupervised
estimators that seeks to describe datasets as low-dimensional manifolds embedded in high-dimensional spaces. When
you think of a manifold, I'd suggest imagining a sheet of paper: this is a two-dimensional object that lives in our familiar
three-dimensional world, and can be bent or rolled in three dimensions. In the parlance of manifold learning, we can
think of this sheet as a two-dimensional manifold embedded in three-dimensional space.
要解决上面的问题,我们可以使用一类被称为流形学习的方法,这是一类无监督学习评估器,试图使用低维度的流形来描述高维度空间的数
据集。当提到流形时,我们可以想象⼀张纸:这是⼀个⼆维的对象,处于我们熟悉的三维世界中,还能在这个基础上被弯曲或翻卷。如果
类推到流形学习中,我们可以将这张纸看成是三维空间中的⼆维流形。
Rotating, re-orienting, or stretching the piece of paper in three-dimensional space doesn't change the flat geometry of the
paper: such operations are akin to linear embeddings. If you bend, curl, or crumple the paper, it is still a two-dimensional
manifold, but the embedding into the three-dimensional space is no longer linear. Manifold learning algorithms would
seek to learn about the fundamental two-dimensional nature of the paper, even as it is contorted to fill the three-dimensional space.
在三维空间中旋转、重定位或者延展这张纸不会改变纸张的平⾯⼏何特性:这样的操作都可以归类为线性嵌⼊操作。如果你弯曲、卷曲或
者翻卷纸张,它仍然是⼀个⼆维流形,但是这些操作在三维空间中不再是线性嵌⼊操作。流形学习算法会试图找到这张纸的⼆维本质,即
使它卷曲延伸在⼀个三维空间中。
Here we will demonstrate a number of manifold methods, going most deeply into a couple of techniques: multidimensional
scaling (MDS), locally linear embedding (LLE), and isometric mapping (IsoMap).
本节中我们会展示一些流形学习方法,并重点深入介绍其中几种技术:多维缩放(MDS)、本地线性嵌入(LLE)和等距映射(IsoMap)。
We begin with the standard imports:
载⼊需要的包:
In [1]: %matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import numpy as np
Manifold Learning: "HELLO"
流形学习的“HELLO”
To make these concepts more clear, let's start by generating some two-dimensional data that we can use to define a
manifold. Here is a function that will create data in the shape of the word "HELLO":
为令这些概念更加清晰,让我们⾸先创建⼀些⼆维数据可以⽤来定义流形。下⾯是创建⼀个“HELLO”形状的数据的函数定义:
In [2]: def make_hello(N=1000, rseed=42):
    # Make a plot with "HELLO" text; save it as PNG
    fig, ax = plt.subplots(figsize=(4, 1))
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
    ax.axis('off')
    ax.text(0.5, 0.4, 'HELLO', va='center', ha='center', weight='bold', size=85)
    fig.savefig('hello.png')
    plt.close(fig)

    # Open this PNG and draw random points inside the "HELLO" glyphs
    from matplotlib.image import imread
    data = imread('hello.png')[::-1, :, 0].T
    rng = np.random.RandomState(rseed)
    X = rng.rand(4 * N, 2)
    i, j = (X * data.shape).astype(int).T
    mask = (data[i, j] < 1)
    X = X[mask]
    X[:, 0] *= (data.shape[0] / data.shape[1])
    X = X[:N]
    return X[np.argsort(X[:, 0])]
Let's call the function and visualize the resulting data:
然后调⽤函数展⽰结果图像:
In [3]: X = make_hello(1000)
colorize = dict(c=X[:, 0], cmap=plt.get_cmap('rainbow', 5))
plt.scatter(X[:, 0], X[:, 1], **colorize)
plt.axis('equal');
The output is two-dimensional and consists of points drawn in the shape of the word "HELLO". This data form will help
us to see visually what these algorithms are doing.
输出结果是⼆维的,包括着沿着HELLO形状绘制的点。这个数据会帮助我们可视化的展⽰算法的⾏为。
Multidimensional Scaling (MDS)
多维缩放(MDS)
Looking at data like this, we can see that the particular choice of x and y values of the dataset are not the most
fundamental description of the data: we can scale, shrink, or rotate the data, and the "HELLO" will still be apparent. For
example, if we use a rotation matrix to rotate the data, the x and y values change, but the data is still fundamentally the
same:
看到这样的数据,我们可以想象,图中这些特殊选择的x和y值并不是数据最基础的描述:我们可以缩放、收缩或者旋转数据,这个HELLO形状仍
然会很明显。例如,如果我们使用旋转矩阵来旋转数据,x和y值会改变,但数据本质上仍然相同:
In [4]: def rotate(X, angle):
    theta = np.deg2rad(angle)
    R = [[np.cos(theta), np.sin(theta)],
         [-np.sin(theta), np.cos(theta)]]
    return np.dot(X, R)

X2 = rotate(X, 20) + 5
plt.scatter(X2[:, 0], X2[:, 1], **colorize)
plt.axis('equal');
This tells us that the x and y values are not necessarily fundamental to the relationships in the data. What is fundamental,
in this case, is the distance between each point and the other points in the dataset. A common way to represent this is to
use a distance matrix: for N points, we construct an N × N array such that entry (i, j) contains the distance between
point i and point j. Let's use Scikit-Learn's efficient pairwise_distances function to do this for our original data:
上面的例子告诉我们数据集中的x和y值并不是数据关系中必不可少的基础成分。在这个情况下,最基础的是数据集中每个点和其他点之间的距
离。使用距离矩阵来表示是一种通用的方法:对于 N 个点,我们构建一个 N × N 的数组,数组中的元素 (i, j) 指代的是点 i 和点 j 之间的距
离。下面我们使用Scikit-Learn中高效的 pairwise_distances 函数来为我们的原始数据创建距离矩阵:
In [5]: from sklearn.metrics import pairwise_distances
D = pairwise_distances(X)
D.shape
Out[5]: (1000, 1000)
As promised, for our N=1,000 points, we obtain a 1000×1000 matrix, which can be visualized as shown here:
意料之中,对于我们N=1000个点,我们获得了⼀个1000×1000的矩阵,我们可以如下可视化这个矩阵:
In [6]: plt.imshow(D, zorder=2, cmap='Blues', interpolation='nearest')
plt.colorbar();
If we similarly construct a distance matrix for our rotated and translated data, we see that it is the same:
如果我们采用同样的方法对旋转平移后的数据构建一个距离矩阵,我们就可以发现它们是相同的:
In [7]: D2 = pairwise_distances(X2)
np.allclose(D, D2)
Out[7]: True
This distance matrix gives us a representation of our data that is invariant to rotations and translations, but the
visualization of the matrix above is not entirely intuitive. In the representation shown in this figure, we have lost any visible
sign of the interesting structure in the data: the "HELLO" that we saw before.
虽然这个距离矩阵为我们提供了一种对旋转和平移保持不变的数据表示,但是对矩阵的可视化结果却不是那么直观。在距离矩阵的可视化图
中,我们看不到数据中有趣结构的任何可视痕迹:也就是前面我们能看到的“HELLO”。
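What pairwise_distances computes here can be sketched with NumPy broadcasting, together with the invariance check from In [7]. This is an illustration only: the random points stand in for the "HELLO" data, the helper names are hypothetical, and the broadcasting approach builds an N × N × d intermediate array, so it suits small N; Scikit-Learn's implementation uses a different formulation.

```python
import numpy as np


def distance_matrix(X):
    """Euclidean distance matrix: entry (i, j) is the distance between rows i and j."""
    diff = X[:, np.newaxis, :] - X[np.newaxis, :, :]   # shape (N, N, d)
    return np.sqrt(np.sum(diff ** 2, axis=-1))


def rotate(X, angle):
    # Same rotation as In [4], written with an explicit array
    theta = np.deg2rad(angle)
    R = np.array([[np.cos(theta), np.sin(theta)],
                  [-np.sin(theta), np.cos(theta)]])
    return X @ R


rng = np.random.default_rng(42)
X = rng.random((100, 2))          # stand-in for the "HELLO" points
D = distance_matrix(X)
D2 = distance_matrix(rotate(X, 20) + 5)
print(D.shape, np.allclose(D, D2))   # → (100, 100) True
```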
Further, while computing this distance matrix from the (x, y) coordinates is straightforward, transforming the distances
back into x and y coordinates is rather difficult. This is exactly what the multidimensional scaling algorithm aims to do:
given a distance matrix between points, it recovers a D-dimensional coordinate representation of the data. Let's see how
it works for our distance matrix, using the precomputed dissimilarity to specify that we are passing a distance matrix:
还有就是,虽然从(x, y)坐标中计算得到距离矩阵是很直接的,但是将距离矩阵转换回(x, y)坐标却是非常困难的。这正是多维缩放算法的目
标:给定点之间的距离矩阵,将其还原成一个 D 维坐标的数据表示。下面在我们的距离矩阵上使用 MDS,通过 precomputed 不相似度来指定我们传递的
是距离矩阵:
In [8]: from sklearn.manifold import MDS
model = MDS(n_components=2, dissimilarity='precomputed', random_state=1)
out = model.fit_transform(D)
plt.scatter(out[:, 0], out[:, 1], **colorize)
plt.axis('equal');
The MDS algorithm recovers one of the possible two-dimensional coordinate representations of our data, using only the
N × N distance matrix describing the relationship between the data points.
MDS算法还原了我们数据的一种可能的二维坐标表示,其中仅仅使用了描述数据点之间关系的 N × N 距离矩阵。
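One closed-form way to turn a distance matrix back into coordinates is classical (Torgerson) MDS: double-center the squared distances to obtain a Gram matrix, then use its top eigenvectors scaled by the square roots of the eigenvalues. This is a different algorithm from Scikit-Learn's MDS, which minimizes a stress function iteratively (SMACOF); the sketch below only illustrates the principle, and the function names are hypothetical.

```python
import numpy as np


def classical_mds(D, n_components=2):
    """Coordinates whose pairwise Euclidean distances approximate D."""
    n = D.shape[0]
    J = np.eye(n) - np.ones((n, n)) / n        # centering matrix
    B = -0.5 * J @ (D ** 2) @ J                # double-centered Gram matrix
    eigvals, eigvecs = np.linalg.eigh(B)       # eigenvalues in ascending order
    top = np.argsort(eigvals)[::-1][:n_components]
    return eigvecs[:, top] * np.sqrt(np.clip(eigvals[top], 0, None))


def distance_matrix(X):
    diff = X[:, np.newaxis, :] - X[np.newaxis, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=-1))


rng = np.random.default_rng(0)
X = rng.random((50, 2))
D = distance_matrix(X)
recovered = classical_mds(D, n_components=2)
# The recovered layout may be rotated or reflected, but its distances match D
print(np.allclose(distance_matrix(recovered), D))
```

As with the MDS output above, only the relative arrangement is recovered: any rotation or reflection of the result has the same distance matrix.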
MDS as Manifold Learning
使用MDS进行流形学习
The usefulness of this becomes more apparent when we consider the fact that distance matrices can be computed from
data in any dimension. So, for example, instead of simply rotating the data in the two-dimensional plane, we can project it
into three dimensions using the following function (essentially a three-dimensional generalization of the rotation matrix
used earlier):
当我们考虑到距离矩阵可以从任何维度的数据中计算得到时,上述特性变得更加明显。例如,我们可以将“HELLO”投射到3维中⽽不是上⾯
的⼆维旋转,使⽤下⾯的函数(基本上就是前⾯的矩阵旋转函数的三维通⽤版本)即可实现:
In [9]: def random_projection(X, dimension=3, rseed=42):
    assert dimension >= X.shape[1]
    rng = np.random.RandomState(rseed)
    C = rng.randn(dimension, dimension)
    e, V = np.linalg.eigh(np.dot(C, C.T))
    return np.dot(X, V[:X.shape[1]])

X3 = random_projection(X, 3)
X3.shape
Out[9]: (1000, 3)
Let's visualize these points to see what we're working with:
将这些三维中投射的点可视化出来:
In [10]: from mpl_toolkits import mplot3d
ax = plt.axes(projection='3d')
ax.scatter3D(X3[:, 0], X3[:, 1], X3[:, 2],
**colorize)
ax.view_init(azim=70, elev=50)
We can now ask the MDS estimator to take this three-dimensional data as input, compute the distance matrix, and then
determine the optimal two-dimensional embedding for this distance matrix. The result recovers a representation of the
original data:
我们可以将这个三维数据作为输⼊代⼊ MDS 评估器,计算距离矩阵,然后求出该距离矩阵最优化的⼆维表⽰。结果还原了原始数据的最基
础特征:
In [11]: model = MDS(n_components=2, random_state=1)
out3 = model.fit_transform(X3)
plt.scatter(out3[:, 0], out3[:, 1], **colorize)
plt.axis('equal');
This is essentially the goal of a manifold learning estimator: given high-dimensional embedded data, it seeks a low-dimensional representation of the data that preserves certain relationships within the data. In the case of MDS, the
quantity preserved is the distance between every pair of points.
这就是流形学习评估器的最基本目标:给定嵌入在高维空间中的数据,它能找到一个低维度的数据表示,并且保留数据中的某些关系。在MDS算
法中,被保留下来的量是每两个点之间的距离。
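Why can MDS recover the layout from the three-dimensional data? The random_projection function from In [9] multiplies by rows of the orthogonal eigenvector matrix returned by np.linalg.eigh, so it is an isometry: every pairwise distance is unchanged, which is exactly the quantity MDS preserves. A quick check on random stand-in data (the distance helper is hypothetical):

```python
import numpy as np


def random_projection(X, dimension=3, rseed=42):
    # Copied from In [9]: project X with rows of an orthogonal eigenvector matrix
    assert dimension >= X.shape[1]
    rng = np.random.RandomState(rseed)
    C = rng.randn(dimension, dimension)
    e, V = np.linalg.eigh(np.dot(C, C.T))
    return np.dot(X, V[:X.shape[1]])


def distance_matrix(X):
    diff = X[:, np.newaxis, :] - X[np.newaxis, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=-1))


X = np.random.default_rng(0).random((200, 2))   # stand-in for the "HELLO" points
X3 = random_projection(X, 3)

# Rebuild the projection matrix used above (same seed): its two rows are orthonormal
C = np.random.RandomState(42).randn(3, 3)
_, V = np.linalg.eigh(C @ C.T)
P = V[:2]
print(np.allclose(P @ P.T, np.eye(2)))                        # orthonormal rows
print(np.allclose(distance_matrix(X3), distance_matrix(X)))   # distances preserved
```

The S-curve embedding below bends the data, so straight-line distances are no longer preserved; that is why MDS fails there.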
Nonlinear Embeddings: Where MDS Fails
⾮线性嵌⼊:MDS会失效
Our discussion thus far has considered linear embeddings, which essentially consist of rotations, translations, and
scalings of data into higher-dimensional spaces. Where MDS breaks down is when the embedding is nonlinear—that is,
when it goes beyond this simple set of operations. Consider the following embedding, which takes the input and contorts
it into an "S" shape in three dimensions:
目前为止我们讨论的都是线性嵌入,基本上就是将数据旋转、平移和缩放到高维空间。MDS会在嵌入是非线性的情况下失效,即嵌入超出了
上述简单操作的范畴。考虑下面这种嵌入,函数会将输入扭曲成一个三维空间中的“S”形状:
In [12]: def make_hello_s_curve(X):
    t = (X[:, 0] - 2) * 0.75 * np.pi
    x = np.sin(t)
    y = X[:, 1]
    z = np.sign(t) * (np.cos(t) - 1)
    return np.vstack((x, y, z)).T

XS = make_hello_s_curve(X)
This is again three-dimensional data, but we can see that the embedding is much more complicated:
结果还是三维数据,但是我们可以看到这个嵌⼊⽐前⾯复杂多了:
In [13]: from mpl_toolkits import mplot3d
ax = plt.axes(projection='3d')
ax.scatter3D(XS[:, 0], XS[:, 1], XS[:, 2],
**colorize);
The fundamental relationships between the data points are still there, but this time the data has been transformed in a
nonlinear way: it has been wrapped-up into the shape of an "S."
数据点之间的基本关系仍然保留下来了,但是这次数据被转换成了⾮线性形状:它们被封装成了⼀个“S”的形状。
If we try a simple MDS algorithm on this data, it is not able to "unwrap" this nonlinear embedding, and we lose track of the
fundamental relationships in the embedded manifold:
如果我们在这之上尝试简单的MDS算法,它将⽆法对这种⾮线性嵌⼊进⾏解封装,也就是我们失去了这个嵌⼊流形的基础关系:
In [14]: from sklearn.manifold import MDS
model = MDS(n_components=2, random_state=2)
outS = model.fit_transform(XS)
plt.scatter(outS[:, 0], outS[:, 1], **colorize)
plt.axis('equal');
The best two-dimensional linear embedding does not unwrap the S-curve, but instead throws out the original y-axis.
二维最好的“线性”嵌入无法将S曲线解封装,而是丢弃了原始数据的y轴。
Nonlinear Manifolds: Locally Linear Embedding
⾮线性流形:本地线性嵌⼊
How can we move forward here? Stepping back, we can see that the source of the problem is that MDS tries to preserve
distances between faraway points when constructing the embedding. But what if we instead modified the algorithm such
that it only preserves distances between nearby points? The resulting embedding would be closer to what we want.
我们该如何进⾏下去呢?回想⼀下,我们发现问题的根源在于MDS保留了相隔很远的点之间的距离。如果我们修改⼀下算法,让它仅仅保
留附近的点之间的距离呢?结果的嵌⼊⽅式会更加接近我们希望得到的。
Visually, we can think of it as illustrated in this figure:
从下⾯的图像中我们可以看到两者之间的区别:
(Figure: LLE vs MDS linkages; the code to produce this figure is in the Appendix)
Here each faint line represents a distance that should be preserved in the embedding. On the left is a representation of
the model used by MDS: it tries to preserve the distances between each pair of points in the dataset. On the right is a
representation of the model used by a manifold learning algorithm called locally linear embedding (LLE): rather than
preserving all distances, it instead tries to preserve only the distances between neighboring points: in this case, the
nearest 100 neighbors of each point.
上面每条淡色的线条代表着在嵌入中需要保留的距离。左图表示的是MDS算法使用的模型:它试图保留数据集中每一对点之间的距离。右图展示的是
被称为本地线性嵌入(LLE)的一种流形学习算法所使用的模型:与其保留所有距离,它仅仅保留邻近点之间的距离:在这个例子中,是每个点
最邻近的100个点。
Thinking about the left panel, we can see why MDS fails: there is no way to flatten this data while adequately preserving
the length of every line drawn between the two points. For the right panel, on the other hand, things look a bit more
optimistic. We could imagine unrolling the data in a way that keeps the lengths of the lines approximately the same. This
is precisely what LLE does, through a global optimization of a cost function reflecting this logic.
再次考虑左图,我们可以发现MDS失效的原因:无法在充分保留任意两点之间每条连线长度的同时将这些数据平铺出来。而对于右图来
说,情况乐观得多,我们可以想象以一种让这些连线长度大致保持不变的方式将数据展开。这正是LLE的
原理,通过对一个反映该逻辑的损失函数进行全局优化来实现。
LLE comes in a number of flavors; here we will use the modified LLE algorithm to recover the embedded two-dimensional
manifold. In general, modified LLE does better than other flavors of the algorithm at recovering well-defined manifolds
with very little distortion:
LLE有许多变体;这里我们采用了改良LLE(modified LLE)算法来还原嵌入的二维流形。在通常情况下,改良LLE在还原良好定义的流形
时要比该算法的其他变体表现得更好,造成的扭曲非常少:
In [15]: from sklearn.manifold import LocallyLinearEmbedding
model = LocallyLinearEmbedding(n_neighbors=100, n_components=2, method='modified',
eigen_solver='dense')
out = model.fit_transform(XS)
fig, ax = plt.subplots()
ax.scatter(out[:, 0], out[:, 1], **colorize)
ax.set_ylim(0.15, -0.15);
The result remains somewhat distorted compared to our original manifold, but captures the essential relationships in the
data!
结果和原始流形⽐较还是存在⼀些变形,但是它还是捕获了数据中的关键关系。
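The two steps behind LLE, reconstructing each point from its nearest neighbors and then finding low-dimensional coordinates that keep the same reconstruction weights, can be sketched in NumPy. This is the standard LLE formulation, not the modified variant (method='modified') used above; the function name, the regularization constant, and the dense N × N matrices are simplifying assumptions suitable only for a few hundred points.

```python
import numpy as np


def lle(X, n_neighbors=10, n_components=2, reg=1e-3):
    """Sketch of standard LLE: reconstruct each point from its neighbours,
    then find low-dimensional coordinates that keep those weights."""
    n = X.shape[0]
    # Squared Euclidean distances; the diagonal is excluded from the neighbour search
    sq = np.sum(X ** 2, axis=1)
    D2 = sq[:, None] + sq[None, :] - 2 * X @ X.T
    np.fill_diagonal(D2, np.inf)
    nbrs = np.argsort(D2, axis=1)[:, :n_neighbors]

    # 1. Weights that best reconstruct each point from its neighbours (rows sum to 1)
    W = np.zeros((n, n))
    for i in range(n):
        Z = X[nbrs[i]] - X[i]                 # neighbours relative to x_i
        C = Z @ Z.T                           # local Gram matrix (k x k)
        trace = np.trace(C)
        # Regularise: C is singular whenever n_neighbors exceeds the data dimension
        C += np.eye(n_neighbors) * (reg * trace if trace > 0 else reg)
        w = np.linalg.solve(C, np.ones(n_neighbors))
        W[i, nbrs[i]] = w / w.sum()

    # 2. Embedding: bottom eigenvectors of (I - W)^T (I - W), skipping the
    #    constant eigenvector whose eigenvalue is ~0
    I_W = np.eye(n) - W
    M = I_W.T @ I_W
    eigvals, eigvecs = np.linalg.eigh(M)
    return eigvecs[:, 1:n_components + 1]


# Usage on a swiss-roll-like point cloud (synthetic stand-in data)
rng = np.random.default_rng(0)
t = 1.5 * np.pi * (1 + 2 * rng.random(300))
h = 10 * rng.random(300)
X = np.column_stack([t * np.cos(t), h, t * np.sin(t)])
Y = lle(X, n_neighbors=12, n_components=2)
print(Y.shape)
```

Because each point's weights sum to one, the constant vector is always an eigenvector of M with eigenvalue zero; skipping it removes the trivial solution in which every point lands at the same place.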
Some Thoughts on Manifold Methods
关于流形⽅法的⼀些思考
Though this story and motivation is compelling, in practice manifold learning techniques tend to be finicky enough that
they are rarely used for anything more than simple qualitative visualization of high-dimensional data.
虽然本节以及上面讲述的动机很吸引人,但是实践中流形学习技巧却是非常挑剔的,导致它们除了对高维度数据进行简单的定性可视化之外很少被使
用。
The following are some of the particular challenges of manifold learning, which all contrast poorly with PCA:
In manifold learning, there is no good framework for handling missing data. In contrast, there are straightforward
iterative approaches for missing data in PCA.
In manifold learning, the presence of noise in the data can "short-circuit" the manifold and drastically change the
embedding. In contrast, PCA naturally filters noise from the most important components.
The manifold embedding result is generally highly dependent on the number of neighbors chosen, and there is
generally no solid quantitative way to choose an optimal number of neighbors. In contrast, PCA does not involve
such a choice.
In manifold learning, the globally optimal number of output dimensions is difficult to determine. In contrast, PCA lets
you find the output dimension based on the explained variance.
In manifold learning, the meaning of the embedded dimensions is not always clear. In PCA, the principal components
have a very clear meaning.
In manifold learning the computational expense of manifold methods scales as O[N^2] or O[N^3]. For PCA, there
exist randomized approaches that are generally much faster (though see the megaman package for some more
scalable implementations of manifold learning).
下⾯列出了⼀些流形学习的缺点,基本上都是与PCA⽐较的:
在流形学习中,没有好的框架来处理缺失数据。相反在PCA中,有很多直接和有效的⽅法实现缺失数据处理。
在流形学习中,数据中的噪声可能会导致流形“短路”从而剧烈地改变嵌入结果。相反,PCA可以很自然地从最重要的成分中过滤掉噪
声。
流形嵌入结果通常高度依赖于邻近点个数的选择,并且通常没有可靠的数值方法来选择邻近点数量的最优解。相反,PCA不涉及这样
的选择。
在流形学习中,输出维度的全局最优解很难得到。相反,PCA通过可解释⽅差可以找到输出维度的全局最优解。
在流形学习中,嵌⼊维度的含义并不总是清晰的。PCA中的主成分有着⾮常明确的含义。
流形学习中流形算法的计算复杂度通常是O[N^2]或O[N^3]。⽽PCA存在随机⽅法,通常计算快许多(当然你也可以参看megaman
包,这⾥包含⼀些更具扩展性的流形学习的计算实现)。
With all that on the table, the only clear advantage of manifold learning methods over PCA is their ability to preserve
nonlinear relationships in the data; for that reason I tend to explore data with manifold methods only after first exploring
them with PCA.
有了上表列出的这些缺点,流形学习对⽐PCA只有⼀个明显的优点,就是它能保留数据中的⾮线性关系;因此,作者建议对数据完成了
PCA分析之后,再采⽤流形学习⽅法。
Scikit-Learn implements several common variants of manifold learning beyond Isomap and LLE: the Scikit-Learn
documentation has a nice discussion and comparison of them. Based on my own experience, I would give the following
recommendations:
For toy problems such as the S-curve we saw before, locally linear embedding (LLE) and its variants (especially
modified LLE), perform very well. This is implemented in sklearn.manifold.LocallyLinearEmbedding .
For high-dimensional data from real-world sources, LLE often produces poor results, and isometric mapping
(IsoMap) seems to generally lead to more meaningful embeddings. This is implemented in
sklearn.manifold.Isomap .
For data that is highly clustered, t-distributed stochastic neighbor embedding (t-SNE) seems to work very well,
though it can be very slow compared to other methods. This is implemented in sklearn.manifold.TSNE .
除了Isomap和LLE之外,Scikit-Learn还实现了一些常见的流形学习方法变体:Scikit-Learn文档中有一篇很好的针对它们的讨论和比较文
章。基于作者自身的经验,给出了下面的一些建议:
对于一些玩具问题,例如我们前面看到的S曲线,本地线性嵌入(LLE)和它的变体(特别是改良LLE)表现得很优秀。它实现在
sklearn.manifold.LocallyLinearEmbedding 。
对于真实世界中的高维度数据,LLE经常产生很差的结果,而等距映射(IsoMap)通常会得到更加有意义的嵌入。它实现在
sklearn.manifold.Isomap 。
对于高度聚集的数据,t-分布随机近邻嵌入(t-SNE)工作得较好,虽然相对其他方法来说它可能非常慢。它实现在
sklearn.manifold.TSNE 。
If you're interested in getting a feel for how these work, I'd suggest running each of the methods on the data in this
section.
如果你对于这些算法的⼯作原理感兴趣,作者建议在本节的数据例⼦上运⾏每⼀个⽅法来查看结果。
Example: Isomap on Faces
例⼦:⼈脸数据上使⽤Isomap
One place manifold learning is often used is in understanding the relationship between high-dimensional data points. A
common case of high-dimensional data is images: for example, a set of images with 1,000 pixels each can be thought of
as a collection of points in 1,000 dimensions – the brightness of each pixel in each image defines the coordinate in that
dimension.
流形学习经常被用来理解高维度数据点之间的关系。图像是高维度数据的常见场景:例如,一组每张有1000个像素的图像可以被看作
1000维空间中的点的集合,每张图像中每个像素的亮度定义了该维度上的坐标值。
Here let's apply Isomap on some faces data. We will use the Labeled Faces in the Wild dataset, which we previously saw
in In-Depth: Support Vector Machines and In Depth: Principal Component Analysis. Running this command will download
the data and cache it in your home directory for later use:
下面我们将Isomap算法应用到一些人脸数据上。我们继续使用Labeled Faces in the Wild数据集,我们在前面的深入:支持向量机和深入:
主成分分析中都使用过它。运行下面的命令会下载数据并将其缓存在你的主目录中,以便后续使用:
In [16]: from sklearn.datasets import fetch_lfw_people
faces = fetch_lfw_people(min_faces_per_person=30)
faces.data.shape
Out[16]: (2370, 2914)
We have 2,370 images, each with 2,914 pixels. In other words, the images can be thought of as data points in a 2,914-dimensional space!
我们有2370张图像,每张都是2914个像素。换言之,这些图像可以被认为是2914维空间中的数据点。
Let's quickly visualize several of these images to see what we're working with:
我们展⽰部分图像,看看我们的数据集是怎样的:
In [17]: fig, ax = plt.subplots(4, 8, subplot_kw=dict(xticks=[], yticks=[]))
for i, axi in enumerate(ax.flat):
    axi.imshow(faces.images[i], cmap='gray')
We would like to plot a low-dimensional embedding of the 2,914-dimensional data to learn the fundamental relationships
between the images. One useful way to start is to compute a PCA, and examine the explained variance ratio, which will
give us an idea of how many linear features are required to describe the data:
我们希望绘制这些2914维数据的低维度嵌入,从而获得这些图像之间的基本关系。从计算PCA开始是一个不错的办法,然后检查可解释方
差的比率,它能让我们大致了解描述数据需要多少个线性特征:
In [18]: from sklearn.decomposition import PCA as RandomizedPCA
model = RandomizedPCA(100, svd_solver='randomized').fit(faces.data)
plt.plot(np.cumsum(model.explained_variance_ratio_))
plt.xlabel('n components')
plt.ylabel('cumulative variance');
We see that for this data, nearly 100 components are required to preserve 90% of the variance: this tells us that the data
is intrinsically very high dimensional—it can't be described linearly with just a few components.
我们从上图可见,需要接近100个成分才能保留90%的方差:这告诉我们,这些数据本质上就是非常高维度的,它无法仅使用几个成分进行线
性描述。
When this is the case, nonlinear manifold embeddings like LLE and Isomap can be helpful. We can compute an Isomap
embedding on these faces using the same pattern shown before:
这种情况下,⾮线性流形嵌⼊如LLE和Isomap⽐较有帮助。我们可以使⽤Isomap嵌⼊来计算这些⼈脸数据的⼆维投射:
In [19]: from sklearn.manifold import Isomap
model = Isomap(n_components=2)
proj = model.fit_transform(faces.data)
proj.shape
Out[19]: (2370, 2)
The output is a two-dimensional projection of all the input images. To get a better idea of what the projection tells us, let's
define a function that will output image thumbnails at the locations of the projections:
输出结果是所有输⼊图像的⼆维投射。要获得这个⼆维投射的意义,我们定义⼀个函数,它会将相应的缩略图绘制在投射点的对应位置:
In [20]: from matplotlib import offsetbox

def plot_components(data, model, images=None, ax=None,
                    thumb_frac=0.05, cmap='gray'):
    ax = ax or plt.gca()

    proj = model.fit_transform(data)
    ax.plot(proj[:, 0], proj[:, 1], '.k')

    if images is not None:
        min_dist_2 = (thumb_frac * max(proj.max(0) - proj.min(0))) ** 2
        shown_images = np.array([2 * proj.max(0)])
        for i in range(data.shape[0]):
            dist = np.sum((proj[i] - shown_images) ** 2, 1)
            if np.min(dist) < min_dist_2:
                # don't show points that are too close to one already shown
                continue
            shown_images = np.vstack([shown_images, proj[i]])
            imagebox = offsetbox.AnnotationBbox(
                offsetbox.OffsetImage(images[i], cmap=cmap),
                proj[i])
            ax.add_artist(imagebox)
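The thumbnail-thinning rule inside plot_components can be isolated and checked on its own: a point gets a thumbnail only if it lies at least thumb_frac times the larger axis span away from every point already shown. In this sketch the helper name thin_points is hypothetical, and it starts from an empty list instead of the far-away dummy point (2 * proj.max(0)) that plot_components uses to avoid an empty array.

```python
import numpy as np


def thin_points(proj, thumb_frac=0.05):
    """Indices of points that would receive a thumbnail under the greedy rule."""
    min_dist_2 = (thumb_frac * np.max(proj.max(0) - proj.min(0))) ** 2
    kept = []
    for i in range(proj.shape[0]):
        if kept:
            dist_2 = np.sum((proj[i] - proj[kept]) ** 2, axis=1)
            if np.min(dist_2) < min_dist_2:
                continue        # too close to a thumbnail already shown
        kept.append(i)
    return kept


proj = np.random.default_rng(0).random((500, 2))   # stand-in for an embedding
kept = thin_points(proj, thumb_frac=0.05)
print(len(kept), "of", len(proj), "points get a thumbnail")
```

Larger thumb_frac values leave more empty space between thumbnails and therefore show fewer images.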
Calling this function now, we see the result:
调⽤函数获得结果:
In [21]: fig, ax = plt.subplots(figsize=(10, 10))
plot_components(faces.data,
model=Isomap(n_components=2),
images=faces.images[:, ::2, ::2])
The result is interesting: the first two Isomap dimensions seem to describe global image features: the overall darkness or
lightness of the image from left to right, and the general orientation of the face from bottom to top. This gives us a nice
visual indication of some of the fundamental features in our data.
这个结果很有趣:Isomap前两个维度看起来是在描述图像的全局特征:从左到右是图像整体的亮度或暗度的变化,从下到上是图像中人脸大致朝
向的变化。这能为我们提供关于数据关键特征的很好的可视化指示。
We could then go on to classify this data (perhaps using manifold features as inputs to the classification algorithm) as we
did in In-Depth: Support Vector Machines.
然后我们可以继续对数据进⾏分类(使⽤流形特征作为分类算法的输⼊),正如我们在深⼊:⽀持向量机中做的那样。
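The three steps behind Isomap, a nearest-neighbor graph, shortest-path (geodesic) distances through that graph, and classical MDS on those distances, can be sketched in NumPy. This is a minimal illustration with hypothetical names, using dense matrices and an O(N^3) Floyd-Warshall loop, so it is only practical for a few hundred points; Scikit-Learn's Isomap follows the same outline with more efficient graph and eigensolver routines.

```python
import numpy as np


def isomap(X, n_neighbors=5, n_components=2):
    """Isomap sketch: k-NN graph, graph shortest paths, then classical MDS."""
    n = X.shape[0]
    D = np.sqrt(np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1))

    # 1. Connect each point to its nearest neighbours (column 0 of argsort is the
    #    point itself, assuming no duplicate points); symmetrise the graph
    nbrs = np.argsort(D, axis=1)[:, 1:n_neighbors + 1]
    rows = np.repeat(np.arange(n), n_neighbors)
    G = np.full((n, n), np.inf)
    G[rows, nbrs.ravel()] = D[rows, nbrs.ravel()]
    G = np.minimum(G, G.T)
    np.fill_diagonal(G, 0.0)

    # 2. Floyd-Warshall: shortest paths through the graph approximate geodesics
    for k in range(n):
        G = np.minimum(G, G[:, k:k + 1] + G[k:k + 1, :])
    if np.isinf(G).any():
        raise ValueError("neighbour graph is disconnected; increase n_neighbors")

    # 3. Classical MDS on the geodesic distance matrix
    J = np.eye(n) - 1.0 / n
    B = -0.5 * J @ (G ** 2) @ J
    eigvals, eigvecs = np.linalg.eigh(B)
    top = np.argsort(eigvals)[::-1][:n_components]
    return eigvecs[:, top] * np.sqrt(np.clip(eigvals[top], 0, None))


# Points along a half circle: a 1-D manifold curled up in 2-D
angles = np.linspace(0, np.pi, 40)
X = np.column_stack([np.cos(angles), np.sin(angles)])
Y = isomap(X, n_neighbors=2, n_components=1)[:, 0]
print(Y.max() - Y.min())   # close to the arc length pi, not the chord length 2
```

Because distances are measured along the graph rather than in a straight line, the half circle is "unrolled" into a segment whose length approximates the arc, which is the behavior that lets Isomap succeed where MDS fails on curved manifolds.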
Example: Visualizing Structure in Digits
例⼦:⼿写数字的结构可视化
As another example of using manifold learning for visualization, let's take a look at the MNIST handwritten digits set. This
data is similar to the digits we saw in In-Depth: Decision Trees and Random Forests, but with many more pixels per
image. It can be downloaded from http://mldata.org/ with the Scikit-Learn utility:
下⾯我们使⽤MNIST⼿写数字数据集来作为流形学习可视化数据的另外⼀个例⼦。这个数据集与深⼊:决策树和随机森林中的类似,但是
每张图像有着更多的像素点。它可以使⽤Scikit-Learn⼯具从http://mldata.org/ 下载:
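Translator's note: fetch_mldata can no longer retrieve datasets from mldata.org; in fact, mldata.org has stopped its service. Replacing it with fetch_openml also raised an error, for reasons the translator could only guess at. You can download the dataset file needed for this section via the "MNIST MAT" file link and place it in the $HOME/scikit_learn_data/mldata directory, creating the directory if it does not exist:
mkdir -p $HOME/scikit_learn_data/mldata
After that, the data no longer needs to be downloaded from the network.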
In [22]: from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')
mnist.data.shape
/home/wangy/anaconda3/lib/python3.7/site-packages/sklearn/utils/deprecation.py:85: DeprecationWarning: Function fetch_mldata is deprecated; fetch_mldata was deprecated in version 0.20 and will be removed in version 0.22. Please use fetch_openml.
  warnings.warn(msg, category=DeprecationWarning)
/home/wangy/anaconda3/lib/python3.7/site-packages/sklearn/utils/deprecation.py:85: DeprecationWarning: Function mldata_filename is deprecated; mldata_filename was deprecated in version 0.20 and will be removed in version 0.22. Please use fetch_openml.
  warnings.warn(msg, category=DeprecationWarning)
Out[22]: (70000, 784)
This consists of 70,000 images, each with 784 pixels (i.e. the images are 28×28). As before, we can take a look at the
first few images:
我们看到这个数据集包含70000张图,每张图有784个像素点(即28×28规格)。惯例上我们先看看数据集前⾯的部分图像:
In [23]: fig, ax = plt.subplots(6, 8, subplot_kw=dict(xticks=[], yticks=[]))
for i, axi in enumerate(ax.flat):
    axi.imshow(mnist.data[1250 * i].reshape(28, 28), cmap='gray_r')
This gives us an idea of the variety of handwriting styles in the dataset.
上图可以让我们⼤概看到⼿写数字的⼀些不同的⻛格。
Let's compute a manifold learning projection across the data. For speed here, we'll only use 1/30 of the data, which is
roughly 2,300 points (because of the relatively poor scaling of manifold learning, I find that a few thousand samples is a
good number to start with for relatively quick exploration before moving to a full calculation):
下面我们来计算这个数据集的流形学习投射。因为速度原因,这里仅使用了数据集的1/30,也就是大约2300个点(因为流形学习相对较高
的计算复杂度,作者认为几千个样本是一个比较合适的起始规模,可以较快地进行探索,之后再进行全样本集的计算):
In [24]: # Use only 1/30 of the data: the full dataset takes a long time
data = mnist.data[::30]
target = mnist.target[::30]

model = Isomap(n_components=2)
proj = model.fit_transform(data)
plt.scatter(proj[:, 0], proj[:, 1], c=target, cmap=plt.get_cmap('jet', 10))
plt.colorbar(ticks=range(10))
plt.clim(-0.5, 9.5);
The resulting scatter plot shows some of the relationships between the data points, but is a bit crowded. We can gain
more insight by looking at just a single number at a time:
结果中的散点图展⽰了数据点之间的⼀些关系,但是看起来显得有点拥挤。我们可以通过⼀次只看⼀个数字来获得更加清晰的展⽰效果:
In [25]: from sklearn.manifold import Isomap

# Choose 1/4 of the "1" digits to project
data = mnist.data[mnist.target == 1][::4]

fig, ax = plt.subplots(figsize=(10, 10))
model = Isomap(n_neighbors=5, n_components=2, eigen_solver='dense')
plot_components(data, model, images=data.reshape((-1, 28, 28)),
                ax=ax, thumb_frac=0.05, cmap='gray_r')
The result gives you an idea of the variety of forms that the number "1" can take within the dataset. The data lies along a
broad curve in the projected space, which appears to trace the orientation of the digit. As you move up the plot, you find
ones that have hats and/or bases, though these are very sparse within the dataset. The projection lets us identify outliers
that have data issues: for example, pieces of the neighboring digits that snuck into the extracted images.
上面的结果展示了不同风格书写的数字1在投射空间中的分布情况。这些数据点分布在一条宽的弧形曲线上,沿着弧形变化的似乎是数字1的书写方向。越往图的上方,越能看到戴着帽子和/或带底座的数字1,不过这类样本在数据集中非常稀少。这个投射也能让我们发现一些存在数据问题的离群点:例如,邻近数字的一部分混入了提取出来的图像当中。
Now, this in itself may not be useful for the task of classifying digits, but it does help us get an understanding of the data,
and may give us ideas about how to move forward, such as how we might want to preprocess the data before building a
classification pipeline.
因此,这不是我们⽤来作为数字分类的好⼯具,但是它能帮助我们理解数据集本⾝,还可以为我们提供⼀些好的想法,例如在我们创建分
类器处理之前可以对数据进⾏预处理。
In Depth: k-Means Clustering
深⼊:K均值聚类
In the previous few sections, we have explored one category of unsupervised machine learning models: dimensionality
reduction. Here we will move on to another class of unsupervised machine learning models: clustering algorithms.
Clustering algorithms seek to learn, from the properties of the data, an optimal division or discrete labeling of groups of
points.
在前⾯两节中,我们研究了⽆监督机器学习的⼀个⼤类:降维。下⾯我们将要继续讨论⽆监督机器学习的另⼀个⼤类:聚类算法。聚类算
法寻求通过数据的属性进⾏学习,然后获得数据点的优化分组或离散标签。
Many clustering algorithms are available in Scikit-Learn and elsewhere, but perhaps the simplest to understand is an
algorithm known as k-means clustering, which is implemented in sklearn.cluster.KMeans .
许多聚类算法在Scikit-Learn和其他包中都实现了,其中最容易理解的算法就是k均值聚类,它被实现在 sklearn.cluster.KMeans 包
中。
We begin with the standard imports:
例⾏导⼊包:
In [1]: %matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()  # for plot styling
import numpy as np
Introducing k-Means
k均值聚类简介
The k-means algorithm searches for a pre-determined number of clusters within an unlabeled multidimensional dataset. It
accomplishes this using a simple conception of what the optimal clustering looks like:
The "cluster center" is the arithmetic mean of all the points belonging to the cluster.
Each point is closer to its own cluster center than to other cluster centers.
k均值是一个在未标记的多维数据集中寻找预先确定数量的聚类的算法。该算法基于下述关于最优聚类的简单概念:
“聚类中心”是该聚类所有点的算术平均。
每一个数据点距离它所属的聚类中心比距离其他聚类中心都要近。
Those two assumptions are the basis of the k-means model. We will soon dive into exactly how the algorithm reaches
this solution, but for now let's take a look at a simple dataset and see the k-means result.
这两个假设是k均值模型的基本原理。我们很快会了解算法是如何达到这个⽬标的,但是现在⾸先让我们使⽤⼀个简单的数据集来查看k均
值的结果。
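Both properties can be checked directly on a fitted model. The sketch below is an illustration, not part of the original text. It assumes the same `make_blobs` settings used in In [2] below and passes `n_init=10` explicitly so the result doesn't depend on the library's default. Agreement between centers and means is checked only up to a small tolerance, because Scikit-Learn may stop iterating once the centers move less than its `tol` setting.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Toy data with the same settings as In [2] below
X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.60, random_state=0)
kmeans = KMeans(n_clusters=4, n_init=10, random_state=0).fit(X)
labels, centers = kmeans.labels_, kmeans.cluster_centers_

# Property 1: each center is (approximately) the mean of its assigned points
means = np.array([X[labels == k].mean(axis=0) for k in range(4)])
print("centers match means:", np.allclose(centers, means, atol=1e-3))

# Property 2: each point is closer to its own center than to any other center
dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
print("points nearest own center:", np.array_equal(dists.argmin(axis=1), labels))
```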
First, let's generate a two-dimensional dataset containing four distinct blobs. To emphasize that this is an unsupervised
algorithm, we will leave the labels out of the visualization:
⾸先,让我们⽣成⼀个⼆维的数据集,内含四个独⽴的群落。为了强调这是⼀个⽆监督算法,我们在图中没有展⽰标签:
译者注:下⾯代码为适应新版scikit-learn去除了 sample_generator 模块以避免产⽣警告。
In [2]: from sklearn.datasets import make_blobs
X, y_true = make_blobs(n_samples=300, centers=4,
cluster_std=0.60, random_state=0)
plt.scatter(X[:, 0], X[:, 1], s=50);
By eye, it is relatively easy to pick out the four clusters. The k-means algorithm does this automatically, and in Scikit-Learn uses the typical estimator API:
通过⾁眼观察很容易能分出四个聚类出来。k均值算法会⾃动完成这个⼯作,在Scikit-Learn中有专⻔的评估器API:
In [3]: from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=4)
kmeans.fit(X)
y_kmeans = kmeans.predict(X)
Let's visualize the results by plotting the data colored by these labels. We will also plot the cluster centers as determined
by the k-means estimator:
然后将模型预测的结果使⽤不同的颜⾊绘制在图中。同时我们也会在图中画出每个聚类的中⼼点,这个中⼼点是评估器计算得到的:
In [4]: plt.scatter(X[:, 0], X[:, 1], c=y_kmeans, s=50, cmap='viridis')
centers = kmeans.cluster_centers_
plt.scatter(centers[:, 0], centers[:, 1], c='black', s=200, alpha=0.5);
The good news is that the k-means algorithm (at least in this simple case) assigns the points to clusters very similarly to
how we might assign them by eye. But you might wonder how this algorithm finds these clusters so quickly! After all, the
number of possible combinations of cluster assignments is exponential in the number of data points—an exhaustive
search would be very, very costly. Fortunately for us, such an exhaustive search is not necessary: instead, the typical
approach to k-means involves an intuitive iterative approach known as expectation–maximization.
可以看到k均值算法(⾄少在这个简单的例⼦当中)对数据点聚类的⽅法与⾁眼⽅式相似。但是你可能会疑惑为什么算法能这么快找出所有
的聚类。因为由简单的数学可知,聚类的可能数量是数据点数量的指数量级。如果穷举所有的可能性的话,那将会⾮常⾮常慢。幸运的
是,我们并不需要这样的穷举:可以使⽤⼀种被称为最⼤期望算法的直观迭代⽅法来实现k均值算法。
k-Means Algorithm: Expectation–Maximization
k均值算法:最大期望算法
Expectation–maximization (E–M) is a powerful algorithm that comes up in a variety of contexts within data science. k-means is a particularly simple and easy-to-understand application of the algorithm, and we will walk through it briefly
here. In short, the expectation–maximization approach here consists of the following procedure:
1. Guess some cluster centers
2. Repeat until converged
A. E-Step: assign points to the nearest cluster center
B. M-Step: set the cluster centers to the mean
最⼤期望算法(E-M)是在数据科学领域⼴泛应⽤的强⼤算法。其中k均值是该算法最简单和易于理解的应⽤场景,这⾥简要介绍⼀下,最
⼤期望算法的步骤如下:
1. 随机猜测聚类中⼼点
2. 重复以下步骤
A. E步骤:将所有数据点分配到最近的聚类中⼼点上
B. M步骤:重新计算每个聚类的中⼼点
Here the "E-step" or "Expectation step" is so-named because it involves updating our expectation of which cluster each
point belongs to. The "M-step" or "Maximization step" is so-named because it involves maximizing some fitness function
that defines the location of the cluster centers—in this case, that maximization is accomplished by taking a simple mean
of the data in each cluster.
这⾥的“E步骤”也叫“期望步骤”,名字的由来是因为它是处理每个数据点归属的聚类的,即我们的期望。“M步骤”也叫“最⼤化步骤”,它的名
字由来是因为⽤来最⼤化定义聚类中⼼点的适配函数的,在这个情况中,最⼤化过程实际上就是对每个聚类的数据点取均值。
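To make this concrete, assume the standard k-means objective: the within-cluster sum of squared distances J. Maximizing the "fitness function" above is the same as minimizing J. Under that objective, each step is an exact minimization of J: the E-step over the assignments c_i, and the M-step over the centers μ_j. Setting the gradient with respect to a center to zero shows why the M-step is just a mean. Here n_j denotes the number of points currently assigned to cluster j.

```latex
J = \sum_{i=1}^{n} \lVert x_i - \mu_{c_i} \rVert^2,
\qquad
c_i = \arg\min_{j} \lVert x_i - \mu_j \rVert^2 \quad \text{(E-step)}

\frac{\partial J}{\partial \mu_j} = -2 \sum_{i:\, c_i = j} (x_i - \mu_j) = 0
\;\Longrightarrow\;
\mu_j = \frac{1}{n_j} \sum_{i:\, c_i = j} x_i \quad \text{(M-step)}
```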
The literature about this algorithm is vast, but can be summarized as follows: under typical circumstances, each repetition
of the E-step and M-step will always result in a better estimate of the cluster characteristics.
要详细描述整个算法可能会很冗⻓,但是可以⽤⼀句话进⾏概括:在典型情况下,每次E步骤和M步骤的迭代都会导致更加准确的聚类特征
结果。
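One way to make "a better estimate" concrete is the within-cluster sum of squared distances (the inertia). With plain E–M k-means, it can never increase from one iteration to the next. The sketch below records it after each E-step on toy blob data. `inertia_trace` is an illustrative helper, not a library function, and keeping the old center when a cluster becomes empty is an added assumption.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import pairwise_distances_argmin

def inertia_trace(X, n_clusters, n_iter=10, rseed=2):
    """Run plain E-M k-means and record the within-cluster sum of squares."""
    rng = np.random.RandomState(rseed)
    centers = X[rng.permutation(X.shape[0])[:n_clusters]]
    trace = []
    for _ in range(n_iter):
        # E-step: assign each point to its nearest center
        labels = pairwise_distances_argmin(X, centers)
        trace.append(((X - centers[labels]) ** 2).sum())
        # M-step: move each center to the mean of its points
        # (keep the old center if a cluster happens to be empty)
        centers = np.array([X[labels == k].mean(axis=0) if np.any(labels == k)
                            else centers[k] for k in range(n_clusters)])
    return np.array(trace)

X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.60, random_state=0)
trace = inertia_trace(X, 4)
print(np.round(trace, 1))
print("never increases:", np.all(np.diff(trace) <= 1e-8))
```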
We can visualize the algorithm as shown in the following figure. For the particular initialization shown here, the clusters
converge in just three iterations. For an interactive version of this figure, refer to the code in the Appendix.
我们可以⽤下⾯图像可视化这个算法过程。对于⼀个左边图中设定的初始化中⼼点的情况,算法仅需要三次迭代的过程就可以完成聚类结
果。如果想要这个图像的交互式版本,请参考附录中的代码。
(run code in Appendix to generate image)
附录中⽣成图像的代码
The k-Means algorithm is simple enough that we can write it in a few lines of code. The following is a very basic
implementation:
k均值算法很简单,我们只需要数行代码就能实现它。下面是一种最基本的实现方式:
In [5]: from sklearn.metrics import pairwise_distances_argmin
def find_clusters(X, n_clusters, rseed=2):
    # 1. Randomly choose initial cluster centers
    rng = np.random.RandomState(rseed)
    i = rng.permutation(X.shape[0])[:n_clusters]
    centers = X[i]
    while True:
        # 2a. Assign labels based on the closest center
        labels = pairwise_distances_argmin(X, centers)
        # 2b. Find new centers from the means of the points
        new_centers = np.array([X[labels == i].mean(0)
                                for i in range(n_clusters)])
        # 2c. Check for convergence: stop when the centers no longer move
        if np.all(centers == new_centers):
            break
        centers = new_centers
    return centers, labels
centers, labels = find_clusters(X, 4)
plt.scatter(X[:, 0], X[:, 1], c=labels,
            s=50, cmap='viridis');
Most well-tested implementations will do a bit more than this under the hood, but the preceding function gives the gist of
the expectation–maximization approach.
当然存在很多良好测试的实现⽅式会⽐上⾯的函数更加健壮,但是上⾯的代码给出了最⼤期望算法的纲要。
Caveats of expectation–maximization
最⼤期望算法的⼀些注意事项
There are a few issues to be aware of when using the expectation–maximization algorithm.
在使⽤最⼤期望算法时,有⼀些注意事项需要留意。
The globally optimal result may not be achieved
全局最优解可能不可得
First, although the E–M procedure is guaranteed to improve the result in each step, there is no assurance that it will lead
to the global best solution. For example, if we use a different random seed in our simple procedure, the particular starting
guesses lead to poor results:
⾸先,虽然E-M算法能保证每次迭代都改善结果,但是它并不能保证最终会产⽣全局最优解。例如,如果我们使⽤了不同的随机种⼦,这
个初始化值可能会产⽣不良的结果:
In [6]: centers, labels = find_clusters(X, 4, rseed=0)
plt.scatter(X[:, 0], X[:, 1], c=labels,
s=50, cmap='viridis');
Here the E–M approach has converged, but has not converged to a globally optimal configuration. For this reason, it is
common for the algorithm to be run for multiple starting guesses, as Scikit-Learn does via the n_init parameter (it
defaulted to 10 in older releases; recent releases default to n_init='auto', which performs a single run when the default
k-means++ initialization is used).
上面例子中E-M算法已经收敛了,但是没有收敛到一个全局最优的结果上。因此,通常需要在不同的初始化条件下多次运行该算法,Scikit-Learn通过 n_init 参数实现这一点(旧版本中默认值为10;新版本默认为 n_init='auto' ,在使用默认的k-means++初始化时只运行一次)。
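The idea of multiple starting guesses can be sketched by hand. Run plain E–M k-means from several random seeds and keep the run with the lowest inertia (within-cluster sum of squared distances). This mirrors in spirit what `n_init` controls in `KMeans`. However, the helpers `lloyd` and `best_of_n_inits` below are hypothetical illustrations, not Scikit-Learn's implementation, which for example uses k-means++ initialization by default.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import pairwise_distances_argmin

def lloyd(X, n_clusters, rseed, max_iter=100):
    """One run of plain k-means from a random start; returns (centers, labels, inertia)."""
    rng = np.random.RandomState(rseed)
    centers = X[rng.permutation(X.shape[0])[:n_clusters]]
    for _ in range(max_iter):
        labels = pairwise_distances_argmin(X, centers)
        new_centers = np.array([X[labels == k].mean(axis=0) if np.any(labels == k)
                                else centers[k] for k in range(n_clusters)])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    labels = pairwise_distances_argmin(X, centers)
    inertia = ((X - centers[labels]) ** 2).sum()
    return centers, labels, inertia

def best_of_n_inits(X, n_clusters, n_init=10):
    """Hypothetical helper: keep the run with the lowest inertia."""
    runs = [lloyd(X, n_clusters, rseed=seed) for seed in range(n_init)]
    best = min(runs, key=lambda run: run[2])
    return best, [run[2] for run in runs]

X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.60, random_state=0)
(best_centers, best_labels, best_inertia), inertias = best_of_n_inits(X, 4)
print(np.round(inertias, 1))
print("best:", round(best_inertia, 1))
```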
The number of clusters must be selected beforehand
聚类的数量必须预先选择
Another common challenge with k-means is that you must tell it how many clusters you expect: it cannot learn the
number of clusters from the data. For example, if we ask the algorithm to identify six clusters, it will happily proceed and
find the best six clusters:
k均值算法的另一个常见挑战是你必须告诉它聚类的个数:它并不能够从数据中学习得到聚类的数量。例如,如果我们要求算法计算6个聚类,它会很顺利地运行并找到最佳的6个聚类:
In [7]: labels = KMeans(6, random_state=0).fit_predict(X)
plt.scatter(X[:, 0], X[:, 1], c=labels,
s=50, cmap='viridis');
Whether the result is meaningful is a question that is difficult to answer definitively; one approach that is rather intuitive,
but that we won't discuss further here, is called silhouette analysis.
产⽣的结果是否有意义通常是⼀个难以准确回答的问题;有⼀种⽐较直观的⽅法可以回答这个问题,叫做轮廓分析,我们这⾥不会详述。
Alternatively, you might use a more complicated clustering algorithm which has a better quantitative measure of the
fitness per number of clusters (e.g., Gaussian mixture models; see In Depth: Gaussian Mixture Models) or which can
choose a suitable number of clusters (e.g., DBSCAN, mean-shift, or affinity propagation, all available in the
sklearn.cluster submodule).
还有其他的选择,你可以使用更加复杂的聚类算法,它们能对每种聚类数量下的拟合程度提供更好的定量度量(例如,高斯混合模型,参见深入:高斯混合模型),或者能够自行选择合适的聚类数量(例如,密度聚类DBSCAN,均值漂移mean-shift或者亲和力传播affinity propagation,它们都实现在 sklearn.cluster 子模块中)。
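The following minimal sketch illustrates an algorithm that chooses the number of clusters itself, using `MeanShift` on the same kind of blob data. Mean-shift takes no cluster count, but it does need a kernel bandwidth. The `quantile=0.2` passed to `estimate_bandwidth` is an arbitrary assumption, and the number of clusters found depends on it.

```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.60, random_state=0)

# Mean-shift needs a kernel bandwidth instead of a cluster count;
# estimate_bandwidth derives one from the data (quantile is a tunable guess)
bandwidth = estimate_bandwidth(X, quantile=0.2, random_state=0)
ms = MeanShift(bandwidth=bandwidth).fit(X)
n_found = len(ms.cluster_centers_)
print("clusters found:", n_found)
```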
k-means is limited to linear cluster boundaries
k均值只限于线性聚类边界
The fundamental model assumptions of k-means (points will be closer to their own cluster center than to others) means
that the algorithm will often be ineffective if the clusters have complicated geometries.
k均值的基本模型假设(数据点距离它所属的中心点比其他中心点都要近)意味着如果聚类有着复杂的几何结构,这个算法通常会失效。
In particular, the boundaries between k-means clusters will always be linear, which means that it will fail for more
complicated boundaries. Consider the following data, along with the cluster labels found by the typical k-means
approach:
严格来说,k均值算法聚类的边界总是线性的,因此在更复杂边界的情况下将⽆法使⽤。考虑下⾯的数据,然后使⽤k均值算法获得聚类结
果:
In [8]: from sklearn.datasets import make_moons
X, y = make_moons(200, noise=.05, random_state=0)
In [9]: labels = KMeans(2, random_state=0).fit_predict(X)
plt.scatter(X[:, 0], X[:, 1], c=labels,
s=50, cmap='viridis');
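Expanding the squared distances shows why the boundary is linear. A point x is closer to center c1 than to c0 exactly when ||x - c0||^2 - ||x - c1||^2 = 2(w.x + b) > 0, with w = c1 - c0 and b = (||c0||^2 - ||c1||^2)/2. That is a straight line, or a hyperplane in higher dimensions. The sketch below checks this numerically on the moons data. `n_init=10` is set explicitly as an assumption, to avoid version-dependent defaults.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_moons

X, _ = make_moons(200, noise=.05, random_state=0)
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
c0, c1 = kmeans.cluster_centers_

# ||x - c0||^2 > ||x - c1||^2  <=>  w.x + b > 0, a linear decision rule
w = c1 - c0
b = (c0 @ c0 - c1 @ c1) / 2
linear_labels = (X @ w + b > 0).astype(int)

print("same as k-means labels:", np.array_equal(linear_labels, kmeans.predict(X)))
```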
This situation is reminiscent of the discussion in In-Depth: Support Vector Machines, where we used a kernel
transformation to project the data into a higher dimension where a linear separation is possible. We might imagine using
the same trick to allow k-means to discover non-linear boundaries.
这种情形让我们想起在深⼊:⽀持向量机中的讨论,我们可以使⽤核转换将数据投射到更⾼的维度上,令线性分类器可以⼯作。我们也可
以在k均值算法上使⽤相同的技巧,令k均值能够⽀持⾮线性的边界。
One version of this kernelized k-means is implemented in Scikit-Learn within the SpectralClustering estimator. It
uses the graph of nearest neighbors to compute a higher-dimensional representation of the data, and then assigns labels
using a k-means algorithm:
这种核化的k均值算法其中一个版本是Scikit-Learn中实现的 SpectralClustering 评估器。它使用最近邻图来计算更高维度的数据表示,然后使用k均值算法来标记数据点:
In [10]: from sklearn.cluster import SpectralClustering
model = SpectralClustering(n_clusters=2, affinity='nearest_neighbors',
assign_labels='kmeans')
labels = model.fit_predict(X)
plt.scatter(X[:, 0], X[:, 1], c=labels,
s=50, cmap='viridis');
/home/wangy/anaconda3/lib/python3.7/site-packages/sklearn/manifold/_spectral_embedding.py:236: UserWarning: Graph is not fully connected, spectral embedding may not work as expected.
warnings.warn("Graph is not fully connected, spectral embedding"
We see that with this kernel transform approach, the kernelized k-means is able to find the more complicated nonlinear
boundaries between clusters.
我们看到经过这种核转换⽅法后,核化的k均值算法能够实现更加复杂的⾮线性聚类边界。
k-means can be slow for large numbers of samples
k均值在大数据集上会慢
Because each iteration of k-means must access every point in the dataset, the algorithm can be relatively slow as the
number of samples grows. You might wonder if this requirement to use all data at each iteration can be relaxed; for
example, you might just use a subset of the data to update the cluster centers at each step. This is the idea behind batch-based k-means algorithms, one form of which is implemented in sklearn.cluster.MiniBatchKMeans . The
interface for this is the same as for standard KMeans ; we will see an example of its use as we continue our discussion.
因为k均值算法上每次迭代都要获取数据集中的每个点,当样本量增加时,算法性能会下降。你可能会觉得是否每次迭代都有必要使⽤全部
数据点;例如,每次迭代仅使⽤数据的⼀个⼦集来更新聚类中⼼点。这种想法就是基于批次的k均值算法,它被实现在
sklearn.cluster.MiniBatchKMeans 当中。批次算法的接⼝与标准k均值算法⼀致;接下来会看到⼀个例⼦。
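A minimal sketch of the mini-batch idea uses `MiniBatchKMeans.partial_fit`, which updates the centers from one chunk of data at a time. The data, the chunking (20 chunks of 500 points) and `n_init=3` are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=10_000, centers=4, cluster_std=0.60, random_state=0)

mbk = MiniBatchKMeans(n_clusters=4, n_init=3, random_state=0)
# Feed the data in chunks; each call updates the centers using only that chunk
for batch in np.array_split(X, 20):
    mbk.partial_fit(batch)

labels = mbk.predict(X)
print(mbk.cluster_centers_.shape)
```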
Examples
例⼦
Being careful about these limitations of the algorithm, we can use k-means to our advantage in a wide variety of
situations. We'll now take a look at a couple examples.
当留意了上述的注意事项后,我们可以将k均值算法应⽤到⼴泛的场景中。下⾯我们看⼏个例⼦。
Example 1: k-means on digits
例1:⼿写数字上应⽤k均值
To start, let's take a look at applying k-means on the same simple digits data that we saw in In-Depth: Decision Trees and
Random Forests and In Depth: Principal Component Analysis. Here we will attempt to use k-means to try to identify
similar digits without using the original label information; this might be similar to a first step in extracting meaning from a
new dataset about which you don't have any a priori label information.
⾸先,我们来看看在简单的⼿写数字数据集上应⽤k均值的情况,数据集与我们在深⼊:随机森林和深⼊:主成分分析中看到的⼀样。不过
我们将尝试不使⽤原始标签信息的情况下,应⽤k均值算法来分辨⼿写的数字;这情况就像我们在遇到⼀个没有任何初始标记的数据时,我
们希望⾸先从中提取出有意义的信息⼀样。
We will start by loading the digits and then finding the KMeans clusters. Recall that the digits consist of 1,797 samples
with 64 features, where each of the 64 features is the brightness of one pixel in an 8×8 image:
当然我们需要载⼊数据然后找到其k均值聚类结果。回想⼀下,我们知道数据有1797个样本,每个样本有64个特征,这些特征代表着8×8图
像中每个像素点的亮度:
In [11]: from sklearn.datasets import load_digits
digits = load_digits()
digits.data.shape
Out[11]: (1797, 64)
The clustering can be performed as we did before:
就像前⾯那样,我们进⾏k均值聚类:
In [12]: kmeans = KMeans(n_clusters=10, random_state=0)
clusters = kmeans.fit_predict(digits.data)
kmeans.cluster_centers_.shape
Out[12]: (10, 64)
The result is 10 clusters in 64 dimensions. Notice that the cluster centers themselves are 64-dimensional points, and can
themselves be interpreted as the "typical" digit within the cluster. Let's see what these cluster centers look like:
结果是在64维空间中的10个聚类。注意每个聚类的中⼼点都是64维空间中的⼀个点,我们可以将它们看做每个聚类的“典型”数字。让我们
将这些中⼼点数字画出来:
In [13]: fig, ax = plt.subplots(2, 5, figsize=(8, 3))
centers = kmeans.cluster_centers_.reshape(10, 8, 8)
for axi, center in zip(ax.flat, centers):
    axi.set(xticks=[], yticks=[])
    axi.imshow(center, interpolation='nearest', cmap=plt.cm.binary)
We see that even without the labels, KMeans is able to find clusters whose centers are recognizable digits, with
perhaps the exception of 1 and 8.
上⾯的结果表明,甚⾄不需要标记, K均值 算法就已经能够分出聚类,并且它们的中⼼点都是可以识别的数字,可能1和8稍微模糊点。
Because k-means knows nothing about the identity of the cluster, the 0–9 labels may be permuted. We can fix this by
matching each learned cluster label with the true labels found in them:
因为k均值算法根本不知道这些聚类的标记,因此0-9的标签不是按照顺序排列的。我们可以将真实的标签和学习到的聚类标签进⾏对应:
In [14]: from scipy.stats import mode
labels = np.zeros_like(clusters)
for i in range(10):
    mask = (clusters == i)
    labels[mask] = mode(digits.target[mask])[0]
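The mode-based relabeling above can, in principle, send two clusters to the same digit. A swapped-in alternative, not used in the original text, is a one-to-one matching via the Hungarian algorithm (`scipy.optimize.linear_sum_assignment`) on the cluster-versus-label contingency table. `match_labels` is a hypothetical helper name; the tiny example checks it on a permuted labeling.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import confusion_matrix

def match_labels(y_true, clusters, n_classes):
    """Hypothetical helper: map cluster ids to class labels one-to-one."""
    # contingency[i, j] = number of points in cluster i whose true label is j
    contingency = confusion_matrix(clusters, y_true, labels=np.arange(n_classes))
    # maximize total agreement = minimize the negated counts
    cluster_ids, class_ids = linear_sum_assignment(-contingency)
    mapping = dict(zip(cluster_ids, class_ids))
    return np.array([mapping[c] for c in clusters])

y_true = np.array([0, 0, 1, 1, 2, 2])
clusters = np.array([2, 2, 0, 0, 1, 1])   # same grouping, permuted ids
print(match_labels(y_true, clusters, n_classes=3))  # -> [0 0 1 1 2 2]
```

With the digits data this would be called as `match_labels(digits.target, clusters, n_classes=10)`.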
Now we can check how accurate our unsupervised clustering was in finding similar digits within the data:
现在我们可以检查这个⽆监督聚类算法结果的准确性了:
In [15]: from sklearn.metrics import accuracy_score
accuracy_score(digits.target, labels)
Out[15]: 0.7952142459654981
With just a simple k-means algorithm, we discovered the correct grouping for 80% of the input digits! Let's check the
confusion matrix for this:
通过⼀个简单的k均值算法,我们就可以对80%左右的输⼊数据进⾏正确的分组。让我们再看看相应的混淆矩阵:
In [16]: from sklearn.metrics import confusion_matrix
mat = confusion_matrix(digits.target, labels)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,
xticklabels=digits.target_names,
yticklabels=digits.target_names)
plt.xlabel('true label')
plt.ylabel('predicted label');
As we might expect from the cluster centers we visualized before, the main point of confusion is between the eights and
ones. But this still shows that using k-means, we can essentially build a digit classifier without reference to any known
labels!
正如我们前面展示的聚类中心所预期到的那样,主要混淆的地方处在数字8和数字1之间。但是这还是能证明我们使用k均值就能构建一个分类器,而不需要参考任何已知的标签。
Just for fun, let's try to push this even farther. We can use the t-distributed stochastic neighbor embedding (t-SNE)
algorithm (mentioned in In-Depth: Manifold Learning) to pre-process the data before performing k-means. t-SNE is a
nonlinear embedding algorithm that is particularly adept at preserving points within clusters. Let's see how it does:
为了更加有趣⼀点,我们再继续深⼊⼀点。我们使⽤t分布随机近邻嵌⼊(t-SNE)算法(参⻅深⼊:流形学习)在k均值算法之前来对数据
进⾏预处理。t-SNE是⼀个⾮线性嵌⼊算法,特别适合⽤来保留聚类的数据点。让我们来看看怎么做:
In [17]: from sklearn.manifold import TSNE
# Project the data: this step may take a while
tsne = TSNE(n_components=2, init='random', random_state=0)
digits_proj = tsne.fit_transform(digits.data)
# Compute the clusters
kmeans = KMeans(n_clusters=10, random_state=0)
clusters = kmeans.fit_predict(digits_proj)
# Permute the labels
labels = np.zeros_like(clusters)
for i in range(10):
    mask = (clusters == i)
    labels[mask] = mode(digits.target[mask])[0]
# Compute the accuracy
accuracy_score(digits.target, labels)
Out[17]: 0.9371174179187535
That's nearly 94% classification accuracy without using the labels. This is the power of unsupervised learning when used
carefully: it can extract information from the dataset that it might be difficult to do by hand or by eye.
这能在不使⽤标签的情况下达到超过93%的分类准确率。这体现了恰当的使⽤⽆监督学习⽅法的威⼒:它能从数据集中提取出关键信息,
⽽这很难使⽤⼿⼯或⾁眼完成。
Example 2: k-means for color compression
例2:k均值进⾏颜⾊压缩
One interesting application of clustering is in color compression within images. For example, imagine you have an image
with millions of colors. In most images, a large number of the colors will be unused, and many of the pixels in the image
will have similar or even identical colors.
聚类还有⼀个有趣的应⽤是在图像的颜⾊压缩领域。例如设想你有⼀张图像包含⼀百万种颜⾊。在⼤多数图像中,⼤量的颜⾊都没有被⽤
到,并且图像中很多的像素都有着相似甚⾄相同的颜⾊。
For example, consider the image shown in the following figure, which is from the Scikit-Learn datasets module (for
this to work, you'll have to have the pillow Python package installed).
例如下⾯这张图像,是Scikit-Learn的 datasets 模块⾃带的(要使下⾯例⼦能成功运⾏,你需要安装 pillow 包)。
In [18]: # Note: this requires the pillow package to be installed
from sklearn.datasets import load_sample_image
china = load_sample_image("china.jpg")
ax = plt.axes(xticks=[], yticks=[])
ax.imshow(china);
The image itself is stored in a three-dimensional array of size (height, width, RGB) , containing red/green/blue
contributions as integers from 0 to 255:
这张图像存储在一个尺寸为 (高度, 宽度, RGB) 的三维数组之中,包含了红/绿/蓝的数值,取值范围是0-255:
In [19]: china.shape
Out[19]: (427, 640, 3)
One way we can view this set of pixels is as a cloud of points in a three-dimensional color space. We will reshape the
data to [n_samples x n_features] , and rescale the colors so that they lie between 0 and 1:
我们也可以将这张图像的像素看成是⼀个三维颜⾊空间中的数据点集合。下⾯我们将数组变形为 [n_samples x n_features] 形状,
然后将颜⾊值转换为取值范围是0-1之间:
In [20]: data = china / 255.0 # use 0...1 scale
data = data.reshape(427 * 640, 3)
data.shape
Out[20]: (273280, 3)
We can visualize these pixels in this color space, using a subset of 10,000 pixels for efficiency:
我们可以在这个颜⾊空间中将这些像素点可视化出来,为了效率起⻅,仅选择10000个像素的⼦数据集:
In [21]: def plot_pixels(data, title, colors=None, N=10000):
    if colors is None:
        colors = data
    # choose a random subset of pixels
    rng = np.random.RandomState(0)
    i = rng.permutation(data.shape[0])[:N]
    colors = colors[i]
    R, G, B = data[i].T
    fig, ax = plt.subplots(1, 2, figsize=(16, 6))
    ax[0].scatter(R, G, color=colors, marker='.')
    ax[0].set(xlabel='Red', ylabel='Green', xlim=(0, 1), ylim=(0, 1))
    ax[1].scatter(R, B, color=colors, marker='.')
    ax[1].set(xlabel='Red', ylabel='Blue', xlim=(0, 1), ylim=(0, 1))
    fig.suptitle(title, size=20);
In [22]: plot_pixels(data, title='Input color space: 16 million possible colors')
Now let's reduce these 16 million colors to just 16 colors, using a k-means clustering across the pixel space. Because we
are dealing with a very large dataset, we will use the mini batch k-means, which operates on subsets of the data to
compute the result much more quickly than the standard k-means algorithm:
下⾯我们将1600万颜⾊调整为16⾊,仅仅需要在像素空间上使⽤k均值聚类算法。因为我们⾯对的是⼀个⾮常巨⼤的数据集,我们将会使
⽤批次k均值算法,该算法每次迭代只会在数据⼦集上进⾏计算,⽐标准的k均值算法要快得多:
译者注:下⾯的warning已经被Numpy修复,因此代码中做了注释。
In [23]: #import warnings; warnings.simplefilter('ignore')
# Fix NumPy issues.
from sklearn.cluster import MiniBatchKMeans
kmeans = MiniBatchKMeans(16)
kmeans.fit(data)
new_colors = kmeans.cluster_centers_[kmeans.predict(data)]
plot_pixels(data, colors=new_colors,
title="Reduced color space: 16 colors")
The result is a re-coloring of the original pixels, where each pixel is assigned the color of its closest cluster center.
Plotting these new colors in the image space rather than the pixel space shows us the effect of this:
得到的结果是重新设置后的原始像素的颜⾊,也就是每个像素的颜⾊被设置成了它所属聚类的中⼼点值。将这些新设置颜⾊的像素转换回
图像空间,然后展⽰出来:
In [24]: china_recolored = new_colors.reshape(china.shape)
fig, ax = plt.subplots(1, 2, figsize=(16, 6),
subplot_kw=dict(xticks=[], yticks=[]))
fig.subplots_adjust(wspace=0.05)
ax[0].imshow(china)
ax[0].set_title('Original Image', size=16)
ax[1].imshow(china_recolored)
ax[1].set_title('16-color Image', size=16);
Some detail is certainly lost in the rightmost panel, but the overall image is still easily recognizable. The image on the
right reduces the number of possible colors by a factor of around 1 million (from about 16 million to 16)! While this is an
interesting application of k-means, there are certainly better ways to compress information in images. But the example
shows the power of thinking outside of the box with unsupervised methods like k-means.
当然右图中损失了一些细节,但整体上图像还是很容易辨认的。右图将可能的颜色数量压缩了大约1百万倍(从约1600万种减少到16种)。虽然这是k均值算法的一个有趣的应用场景,但是显然压缩图像信息还有更好的方法。这个例子为我们展现了类似k均值这样的无监督方法还能在一些意想之外的场景中发挥作用。
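The "factor of around 1 million" is a ratio of color counts: 2^24 possible 24-bit colors versus 16. The stored data shrinks far less. The rough worked estimate below makes two assumptions. The original is stored as 24 bits per pixel. The quantized image is stored as a 16-entry RGB palette plus a 4-bit index per pixel, with no further entropy coding.

```python
import math

height, width = 427, 640            # shape of china.jpg, from In [19]
n_pixels = height * width
n_colors = 16

# Ratio in the number of representable colors: 24-bit RGB vs. 16 palette entries
color_ratio = 2 ** 24 // n_colors

# Raw storage (assumption): 24 bits per pixel originally; afterwards a
# 16-entry RGB palette plus a 4-bit palette index per pixel
original_bits = n_pixels * 24
index_bits = n_pixels * math.ceil(math.log2(n_colors))
palette_bits = n_colors * 24
storage_ratio = original_bits / (index_bits + palette_bits)

print(color_ratio)              # 1048576
print(round(storage_ratio, 2))  # 6.0
```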
In Depth: Gaussian Mixture Models
深⼊:⾼斯混合模型
The k-means clustering model explored in the previous section is simple and relatively easy to understand, but its
simplicity leads to practical challenges in its application. In particular, the non-probabilistic nature of k-means and its use
of simple distance-from-cluster-center to assign cluster membership leads to poor performance for many real-world
situations. In this section we will take a look at Gaussian mixture models (GMMs), which can be viewed as an extension
of the ideas behind k-means, but can also be a powerful tool for estimation beyond simple clustering.
前⼀⼩节讨论的k均值聚类模型是简单和相对容易理解的,但是简单的代价就是在实际应⽤中会遇到挑战。具体来说,k均值的⾮概率本质
和其简单的依据与中⼼点距离来划分聚类的⽅式,决定了在很多真实世界情况中表现很不理想。本节中我们会学习⾼斯混合模型,它被认
为是k均值算法的⼀种拓展,且能作为超越简单聚类的⼀个强⼤⼯具。
We begin with the standard imports:
导⼊包:
In [1]: %matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import numpy as np
Motivating GMM: Weaknesses of k-Means
引入GMM的动机:k均值的弱点
Let's take a look at some of the weaknesses of k-means and think about how we might improve the cluster model. As we
saw in the previous section, given simple, well-separated data, k-means finds suitable clustering results.
让我们先来看⼀下k均值算法的弱点,然后考虑如何能够改进它。正如上⼀节中我们看到的,给定简单的良好分离的数据,k均值能够找到
正确的聚类结果。
For example, if we have simple blobs of data, the k-means algorithm can quickly label those clusters in a way that closely
matches what we might do by eye:
例如我们有⼀些简单的数据群落,k均值算法能够迅速的标记这些聚类,得到的结果符合我们⾁眼观测的情况:
译者注:下⾯代码为适应新版scikit-learn去除了 sample_generator 模块以避免产⽣警告。
In [2]: # Generate some data
from sklearn.datasets import make_blobs
X, y_true = make_blobs(n_samples=400, centers=4,
                       cluster_std=0.60, random_state=0)
X = X[:, ::-1] # flip axes for better plotting
In [3]: # Plot the data with k-means labels
from sklearn.cluster import KMeans
kmeans = KMeans(4, random_state=0)
labels = kmeans.fit(X).predict(X)
plt.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis');
From an intuitive standpoint, we might expect that the clustering assignment for some points is more certain than others:
for example, there appears to be a very slight overlap between the two middle clusters, such that we might not have
complete confidence in the cluster assignment of points between them. Unfortunately, the k-means model has no intrinsic
measure of probability or uncertainty of cluster assignments (although it may be possible to use a bootstrap approach to
estimate this uncertainty). For this, we must think about generalizing the model.
直觉上来看,我们会发现图中某些数据点的聚类结果会比其他点的结果更加确定:例如图中中间两个聚类之间有很少量的重叠部分,这附近的数据点我们并不能非常确定其从属于相应的聚类。遗憾的是,k均值模型本身并没有衡量聚类分配的概率或不确定性的内在度量(尽管可以使用自助法(bootstrap)来估计这种不确定性)。为此,我们必须考虑对模型进行泛化。
One way to think about the k-means model is that it places a circle (or, in higher dimensions, a hyper-sphere) at the
center of each cluster, with a radius defined by the most distant point in the cluster. This radius acts as a hard cutoff for
cluster assignment within the training set: any point outside this circle is not considered a member of the cluster. We can
visualize this cluster model with the following function:
我们考虑k均值算法时,可以认为它是以中⼼点为圆⼼(球⼼),距离中⼼点最远的点的距离为半径的⼀个圆(或者在⾼维空间中是⼀个超
球体)。这些范围被认为是训练集聚类的硬边界:任何边界外的数据点都不会被认为是属于该聚类的。通过下⾯的函数可以将这个想法可
视化出来:
In [4]: from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
def plot_kmeans(kmeans, X, n_clusters=4, rseed=0, ax=None):
    labels = kmeans.fit_predict(X)
    # plot the input data
    ax = ax or plt.gca()
    ax.axis('equal')
    ax.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis', zorder=2)
    # plot the representation of the k-means model
    centers = kmeans.cluster_centers_
    radii = [cdist(X[labels == i], [center]).max()
             for i, center in enumerate(centers)]
    for c, r in zip(centers, radii):
        ax.add_patch(plt.Circle(c, r, fc='#CCCCCC', lw=3, alpha=0.5, zorder=1))
In [5]: kmeans = KMeans(n_clusters=4, random_state=0)
plot_kmeans(kmeans, X)
An important observation for k-means is that these cluster models must be circular: k-means has no built-in way of
accounting for oblong or elliptical clusters. So, for example, if we take the same data and transform it, the cluster
assignments end up becoming muddled:
k均值一个重要的特点是这些聚类模型必须是圆形的:k均值没有内建的方式来处理长条形或者椭圆形的聚类。因此如果我们将同样的数据转换一下,聚类的结果将会变得混乱起来:
In [6]: rng = np.random.RandomState(13)
X_stretched = np.dot(X, rng.randn(2, 2))
kmeans = KMeans(n_clusters=4, random_state=0)
plot_kmeans(kmeans, X_stretched)
By eye, we recognize that these transformed clusters are non-circular, and thus circular clusters would be a poor fit.
Nevertheless, k-means is not flexible enough to account for this, and tries to force-fit the data into four circular clusters.
This results in a mixing of cluster assignments where the resulting circles overlap: see especially the bottom-right of this
plot. One might imagine addressing this particular situation by preprocessing the data with PCA (see In Depth: Principal
Component Analysis), but in practice there is no guarantee that such a global operation will circularize the individual data.
⾁眼观察可知转换后的聚类不是圆形的,因此圆形的聚类模型会导致不理想的拟合。然⽽,k均值并不具有这样的灵活性,仍然按照圆形将
数据点分在四个聚类当中。这样导致的结果会是圆形的⼤⾯积重合:图中右下⻆的部分很清晰地展⽰了这⼀点。⼀种可能的解决⽅案是使
⽤PCA(参⻅深⼊:主成分分析)对数据进⾏预处理,但是实践当中并不能保证这样的全局操作能对所有数据集都能产⽣圆形的模型聚
类。
These two disadvantages of k-means—its lack of flexibility in cluster shape and lack of probabilistic cluster assignment—
mean that for many datasets (especially low-dimensional datasets) it may not perform as well as you might hope.
k均值的这两个缺点,即聚类形状缺乏灵活性和缺少概率化的聚类分配,意味着对于很多数据集(特别是低维度数据集)来说,它可能不会按照预期那样工作。
You might imagine addressing these weaknesses by generalizing the k-means model: for example, you could measure
uncertainty in cluster assignment by comparing the distances of each point to all cluster centers, rather than focusing on
just the closest. You might also imagine allowing the cluster boundaries to be ellipses rather than circles, so as to account
for non-circular clusters. It turns out these are two essential components of a different type of clustering model, Gaussian
mixture models.
也许你会想到通过对k均值模型进行泛化来解决这些缺点:例如你可以通过对每个数据点计算其与所有的聚类中心点的距离来测算不确定度,而不是仅考虑最近的中心点。你也可以将聚类边界泛化成椭圆而不是圆形来适配非圆形聚类。这两个泛化技巧正是另外一种聚类模型,高斯混合模型,的两个核心组成部分。
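The sketch below illustrates the first generalization under one assumption: squared distances are weighted in a Gaussian style, a choice made here for illustration that k-means itself does not define. `KMeans.transform` returns each point's distance to every center, and these distances can be turned into membership weights that sum to one. `soft_assignments` and its `scale` parameter are hypothetical. The weights are largest for the nearest center, so their argmax reproduces the hard k-means labels.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=400, centers=4, cluster_std=0.60, random_state=0)
kmeans = KMeans(n_clusters=4, n_init=10, random_state=0).fit(X)

def soft_assignments(kmeans, X, scale=1.0):
    """Hypothetical helper: turn distances to all centers into membership weights."""
    d2 = kmeans.transform(X) ** 2                # squared distance to every center
    logits = -d2 / (2 * scale ** 2)
    logits -= logits.max(axis=1, keepdims=True)  # for numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=1, keepdims=True)

W = soft_assignments(kmeans, X)
print(W[:3].round(3))
```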
Generalizing E–M: Gaussian Mixture Models
泛化期望最⼤化算法:⾼斯混合模型
A Gaussian mixture model (GMM) attempts to find a mixture of multi-dimensional Gaussian probability distributions that
best model any input dataset. In the simplest case, GMMs can be used for finding clusters in the same manner as k-means:
⾼斯混合模型(GMM)试图在输⼊数据集中找到多维⾼斯概率混合分布模型。在最简单情况下,GMM能够像k均值⼀样的对数据进⾏聚
类:
译者注:新版Scikit-Learn中,GMM已经改为GaussianMixture,下⾯代码相应修改。
In [7]: from sklearn.mixture import GaussianMixture
gmm = GaussianMixture(n_components=4).fit(X)
labels = gmm.predict(X)
plt.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis');
But because GMM contains a probabilistic model under the hood, it is also possible to find probabilistic cluster assignments: in Scikit-Learn this is done using the predict_proba method. This returns a matrix of size [n_samples, n_clusters] which measures
the probability that any point belongs to the given cluster:
但是因为⾼斯混合模型内置了概率模型,因此能使⽤Scikit-Learn中的 predict_proba ⽅法获得聚类的概率数据。该⽅法会返回⼀个
[n_samples, n_clusters] 的数组,⾥⾯包含每个数据点从属于每个聚类的概率值:
In [8]: probs = gmm.predict_proba(X)
print(probs[:5].round(3))
[[0.531 0.    0.469 0.   ]
 [0.    1.    0.    0.   ]
 [0.    1.    0.    0.   ]
 [1.    0.    0.    0.   ]
 [0.    1.    0.    0.   ]]
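A quick sanity check of the `predict_proba` output runs on freshly generated blob data (`random_state=0` is an arbitrary choice). Each row is a probability distribution over the four components, and taking the most probable component reproduces `predict`.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

X, _ = make_blobs(n_samples=400, centers=4, cluster_std=0.60, random_state=0)
gmm = GaussianMixture(n_components=4, random_state=0).fit(X)
probs = gmm.predict_proba(X)

print(probs.shape)
print("rows sum to 1:", np.allclose(probs.sum(axis=1), 1.0))
print("argmax == predict:", np.array_equal(probs.argmax(axis=1), gmm.predict(X)))
```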
We can visualize this uncertainty by, for example, making the size of each point proportional to the certainty of its
prediction; looking at the following figure, we can see that it is precisely the points at the boundaries between clusters that
reflect this uncertainty of cluster assignment:
⽐⽅说我们可以将这种不确定度可视化成每个数据点的⼤⼩,越确定的⾯积越⼤;就像下图所⽰,在两个聚类边界附近的数据点⼩⼀些,
表明不确定性更⾼:
In [9]: size = 50 * probs.max(1) ** 2  # square emphasizes differences
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis', s=size);
Under the hood, a Gaussian mixture model is very similar to k-means: it uses an expectation–maximization approach
which qualitatively does the following:
1. Choose starting guesses for the location and shape
2. Repeat until converged:
A. E-step: for each point, find weights encoding the probability of membership in each cluster
B. M-step: for each cluster, update its location, normalization, and shape based on all data points, making use of
the weights
底层实现上,⾼斯混合模型与k均值很相似:它也使⽤期望最⼤化算法,其算法步骤如下:
1. 选择初始位置和形状
2. 重复以下步骤:
A. E-step:对每个数据点,找到其从属于每个聚类的概率值
B. M-step:对每个聚类,依据所有从属的数据点,使⽤权重值计算更新它的位置,标准化值和形状
The result of this is that each cluster is associated not with a hard-edged sphere, but with a smooth Gaussian model. Just
as in the k-means expectation–maximization approach, this algorithm can sometimes miss the globally optimal solution,
and thus in practice multiple random initializations are used.
这样算法得到的结果,每个聚类不是依据硬边界的圆形来区分,⽽是依据平滑的⾼斯模型来区分。不过与k均值的期望最⼤化算法⼀样,上
述算法有时也会得不到全局最优解,因此实践中也需要使⽤多个随机的初始值。
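The E–M loop above can be sketched directly in NumPy/SciPy for full covariances. This is an illustrative re-implementation, not Scikit-Learn's code; `GaussianMixture` works in log space and initializes with k-means by default. Three things here are assumptions: `gmm_em`, the random-point initialization, and the small `reg` term added to each covariance. The sketch records the data log-likelihood at each E-step, which E–M should never decrease (up to rounding).

```python
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.datasets import make_blobs

def gmm_em(X, n_components, n_iter=30, rseed=0, reg=1e-6):
    """Minimal E-M for a full-covariance Gaussian mixture (illustrative sketch)."""
    rng = np.random.RandomState(rseed)
    n, d = X.shape
    # 1. Starting guesses: random data points as locations, overall covariance as shape
    means = X[rng.permutation(n)[:n_components]]
    covs = np.array([np.cov(X.T) for _ in range(n_components)])
    weights = np.full(n_components, 1.0 / n_components)
    loglik = []
    for _ in range(n_iter):
        # E-step: weighted density of every point under every component
        dens = np.column_stack([w * multivariate_normal(m, c).pdf(X)
                                for w, m, c in zip(weights, means, covs)])
        total = dens.sum(axis=1, keepdims=True)
        loglik.append(np.log(total).sum())
        resp = dens / total                      # membership probabilities
        # M-step: update weights, locations and shapes using those probabilities
        nk = resp.sum(axis=0)
        weights = nk / n
        means = (resp.T @ X) / nk[:, None]
        covs = np.array([(resp[:, k, None] * (X - means[k])).T @ (X - means[k]) / nk[k]
                         + reg * np.eye(d) for k in range(n_components)])
    return weights, means, covs, np.array(loglik)

X, _ = make_blobs(n_samples=400, centers=4, cluster_std=0.60, random_state=0)
weights, means, covs, loglik = gmm_em(X, 4)
print(np.round(loglik[[0, 1, -1]], 1))
```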
Let's create a function that will help us visualize the locations and shapes of the GMM clusters by drawing ellipses based
on the GMM output:
下⾯我们创建⼀个函数来帮助我们根据GMM输出结果来绘制聚类的位置和形状:
译者注:新版Scikit-Learn的GaussianMixture评估器中已经不再使⽤covar属性了,代码中修改为covariances。
In [10]: from matplotlib.patches import Ellipse
def draw_ellipse(position, covariance, ax=None, **kwargs):
    """Draw an ellipse with a given position and covariance"""
    ax = ax or plt.gca()
    # Convert covariance to principal axes
    if covariance.shape == (2, 2):
        U, s, Vt = np.linalg.svd(covariance)
        angle = np.degrees(np.arctan2(U[1, 0], U[0, 0]))
        width, height = 2 * np.sqrt(s)
    else:
        angle = 0
        width, height = 2 * np.sqrt(covariance)
    # Draw the ellipse at 1, 2 and 3 standard deviations
    for nsig in range(1, 4):
        ax.add_patch(Ellipse(position, nsig * width, nsig * height,
                             angle=angle, **kwargs))
def plot_gmm(gmm, X, label=True, ax=None):
    ax = ax or plt.gca()
    labels = gmm.fit(X).predict(X)
    if label:
        ax.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='viridis', zorder=2)
    else:
        ax.scatter(X[:, 0], X[:, 1], s=40, zorder=2)
    ax.axis('equal')
    w_factor = 0.2 / gmm.weights_.max()
    for pos, covar, w in zip(gmm.means_, gmm.covariances_, gmm.weights_):
        draw_ellipse(pos, covar, alpha=w * w_factor)
With this in place, we can take a look at what the four-component GMM gives us for our initial data:
有了这些函数,我们可以对前⾯的数据集在GMM下的模型进⾏展⽰:
In [11]: gmm = GaussianMixture(n_components=4, random_state=42)
plot_gmm(gmm, X)
Similarly, we can use the GMM approach to fit our stretched dataset; allowing for a full covariance the model will fit even
very oblong, stretched-out clusters:
相应的,可以使⽤GMM⽅法来拟合拉伸的数据集;设置完全协⽅差可令模型能拟合⻓条形、拉伸的聚类:
In [12]: gmm = GaussianMixture(n_components=4, covariance_type='full', random_state=42)
plot_gmm(gmm, X_stretched)
This makes clear that GMM addresses the two main practical issues with k-means encountered before.
这样能清楚看出GMM能解决前面说到的k均值的两大问题。
Choosing the covariance type
选择协⽅差的类型
If you look at the details of the preceding fits, you will see that in the original text the covariance_type option was set
differently within each: the first fit relied on the default of the older GMM estimator, covariance_type="diag", while the
second set covariance_type="full". (The current GaussianMixture estimator defaults to covariance_type="full", so with the
code above both fits actually use full covariances.) This hyperparameter controls the degrees of freedom in the shape of
each cluster; it is essential to set this carefully for any given problem. Setting covariance_type="diag" means that the
size of the cluster along each dimension can be set independently, with the resulting ellipse constrained to align with
the axes. A slightly simpler and faster model is covariance_type="spherical", which constrains the shape of the cluster
such that all dimensions are equal. The resulting clustering will have similar characteristics to that of k-means, though it
is not entirely equivalent. A more complicated and computationally expensive model (especially as the number of
dimensions grows) is to use covariance_type="full", which allows each cluster to be modeled as an ellipse with arbitrary
orientation.
如果你仔细观察刚才的例子,就会发现 covariance_type 的设置是不一样的(原书中第一个例子使用旧版 GMM 评估器的默认值 'diag',
第二个例子设置为 'full';新版 GaussianMixture 的默认值是 'full',因此上面两次拟合实际上都使用完全协方差)。这个超参数控制着
聚类形状的自由度;对于每个不同的问题来说,小心地设置这个参数值是非常重要的。covariance_type="diag" 表示聚类形状的每个维度
尺寸都可以独立地取值,结果就是聚类的形状是沿着坐标轴伸展的椭圆形。还有一个更简单和快速的模型是 covariance_type="spherical",
这样产生的结果是聚类的形状是一个圆形,在每个坐标轴上的尺寸都是相等的,因此它的效果与k均值相似,虽然两者有一定差别。还有一个
更复杂和计算量大的模型(尤其是当数据的维度增加时)是 covariance_type="full",产生的结果是每个聚类都是一个任意方向伸展的
椭圆,其中每个维度的尺寸都是独立取值的。
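One concrete way to see the difference in degrees of freedom is the shape of the fitted covariances_ attribute of GaussianMixture under each setting. The sketch below fits a three-component model to arbitrary random 2D data. It also includes covariance_type='tied' (one covariance matrix shared by all components), a fourth option Scikit-Learn offers that the text does not discuss.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(0)
X = rng.randn(300, 2)  # arbitrary toy 2D data

shapes = {}
for cov_type in ['spherical', 'diag', 'tied', 'full']:
    gmm = GaussianMixture(n_components=3, covariance_type=cov_type,
                          random_state=0).fit(X)
    # one variance per component, one per component and dimension,
    # one shared matrix, or one full matrix per component
    shapes[cov_type] = gmm.covariances_.shape
    print(cov_type, shapes[cov_type])
```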
We can see a visual representation of these three choices for a single cluster within the following figure:
下⾯的图中我们可以看到三种不同取值的区别:
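[Figure: the three covariance types (spherical, diag, full) for a single cluster. The code that generates this figure is in the Appendix.]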
GMM as Density Estimation
使⽤⾼斯混合模型作为密度估计
Though GMM is often categorized as a clustering algorithm, fundamentally it is an algorithm for density estimation. That
is to say, the result of a GMM fit to some data is technically not a clustering model, but a generative probabilistic model
describing the distribution of the data.
虽然GMM经常被归为聚类算法,但它本质是一个密度估计算法。这就是说,GMM在一些数据上拟合的结果在技术上来说不是聚类模型,
而是一个生成概率模型,用来描述数据的分布。
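Because a fitted GMM is a probability density, exponentiating its score_samples output (which is the log-density) should give a function that integrates to one. The sketch below checks this numerically on a 1D toy dataset; the data, grid range and grid spacing are arbitrary choices. It also draws a few new points with sample, which in current Scikit-Learn returns a (samples, component labels) tuple.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(0)
x = np.concatenate([rng.normal(0, 1, 300), rng.normal(5, 0.5, 200)])
gmm = GaussianMixture(n_components=2, random_state=0).fit(x[:, None])

# score_samples returns log p(x); integrate p(x) with a simple Riemann sum
grid = np.linspace(-10, 15, 5001)
dx = grid[1] - grid[0]
density = np.exp(gmm.score_samples(grid[:, None]))
area = density.sum() * dx
print("area under the fitted density:", area)

# As a generative model, the GMM can also produce new data
X_new, labels = gmm.sample(5)
print(X_new.ravel(), labels)
```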
As an example, consider some data generated from Scikit-Learn's make_moons function, which we saw in In Depth: k-Means Clustering:
下⾯使⽤在深⼊:K均值聚类⼀节中使⽤Scikit-Learn的 make_moons 函数⽣成的数据作为例⼦:
In [13]: from sklearn.datasets import make_moons
Xmoon, ymoon = make_moons(200, noise=.05, random_state=0)
plt.scatter(Xmoon[:, 0], Xmoon[:, 1]);
If we try to fit this with a two-component GMM viewed as a clustering model, the results are not particularly useful:
如果我们尝试使⽤两个成分的GMM模型来拟合它时,产⽣的结果并没有意义:
In [14]: gmm2 = GaussianMixture(n_components=2, covariance_type='full', random_state=0)
plot_gmm(gmm2, Xmoon)
But if we instead use many more components and ignore the cluster labels, we find a fit that is much closer to the input
data:
但是如果使⽤更多的聚类数量并忽略聚类产⽣的标签的话,我们会得到⼀个更接近输⼊数据的模型结果:
In [15]: gmm16 = GaussianMixture(n_components=16, covariance_type='full', random_state=0)
plot_gmm(gmm16, Xmoon, label=False)
Here the mixture of 16 Gaussians serves not to find separated clusters of data, but rather to model the overall distribution
of the input data. This is a generative model of the distribution, meaning that the GMM gives us the recipe to generate
new random data distributed similarly to our input. For example, here are 400 new points drawn from this 16-component
GMM fit to our original data:
上⾯的例⼦并不是将数据分为16个⾼斯混合模型聚类,⽽是产⽣了数据的分布模型。这就是分布⽣成模型,代表着GMM为我们提供了按照
数据分布情况产⽣新的随机数据的⽅法。例如,下⾯是使⽤这个GMM模型产⽣400个新数据点的⽅法,它们将符合原始数据的分布情况:
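Translator's note: the signature and return value of GaussianMixture's sample method have changed (it now returns a tuple of samples and labels); the code below has been adjusted accordingly.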
In [16]: Xnew, _ = gmm16.sample(400)
plt.scatter(Xnew[:, 0], Xnew[:, 1]);
GMM is convenient as a flexible means of modeling an arbitrary multi-dimensional distribution of data.
GMM可以作为产生任意多维度分布数据的生成模型的方便工具。
How many components?
需要多少成分?
The fact that GMM is a generative model gives us a natural means of determining the optimal number of components for
a given dataset. A generative model is inherently a probability distribution for the dataset, and so we can simply evaluate
the likelihood of the data under the model, using cross-validation to avoid over-fitting. Another means of correcting for
over-fitting is to adjust the model likelihoods using some analytic criterion such as the Akaike information criterion (AIC)
or the Bayesian information criterion (BIC). Scikit-Learn's GaussianMixture estimator includes built-in methods that
compute both of these, so this approach is very easy to apply.
GMM是一个生成模型的事实提供给我们一个自然的方式,来获得对于给定的数据集所需的成分数量。一个生成模型实际上是数据集的概率
分布,因此我们可以在模型下计算数据的似然度,然后使用交叉验证来避免过拟合。另一个解决过拟合的方法是使用分析标准来调整模型的
似然度,例如赤池信息量准则(AIC)和贝叶斯信息量准则(BIC)。Scikit-Learn的 GaussianMixture 评估器包含了内建的方法能够计算
上述准则,非常方便使用。
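The AIC and BIC reported by GaussianMixture can be reproduced from the total log-likelihood and a parameter count. The sketch below assumes covariance_type='full'. With k components in d dimensions, that model has k*d*(d+1)/2 covariance parameters, k*d mean parameters and k - 1 free mixture weights. It then applies AIC = -2 ln L + 2p and BIC = -2 ln L + p ln N. The choice of k = 4 is arbitrary.

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.mixture import GaussianMixture

X, _ = make_moons(200, noise=0.05, random_state=0)
n_samples, n_features = X.shape
k = 4  # arbitrary number of components for the check
gmm = GaussianMixture(k, covariance_type='full', random_state=0).fit(X)

# score() is the mean log-likelihood per sample, so multiply by N for ln L
log_L = gmm.score(X) * n_samples

# free parameters for 'full': covariances + means + (k - 1) independent weights
p = k * n_features * (n_features + 1) // 2 + k * n_features + (k - 1)

aic = -2 * log_L + 2 * p
bic = -2 * log_L + p * np.log(n_samples)
print(p, aic, gmm.aic(X), bic, gmm.bic(X))
```

BIC's penalty grows with ln N rather than staying at 2 per parameter, which is why it tends to favor simpler models than AIC once N exceeds about 7.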
Let's look at the AIC and BIC as a function of the number of GMM components for our moon dataset:
让我们看看在上⾯的数据集中,AIC和BIC随着GMM成分数量变化的情况:
In [17]: n_components = np.arange(1, 21)
models = [GaussianMixture(n, covariance_type='full', random_state=0).fit(Xmoon)
          for n in n_components]
plt.plot(n_components, [m.bic(Xmoon) for m in models], label='BIC')
plt.plot(n_components, [m.aic(Xmoon) for m in models], label='AIC')
plt.legend(loc='best')
plt.xlabel('n_components');
The optimal number of components is the value that minimizes the AIC or BIC, depending on which approximation we wish to
use. The AIC tells us that our choice of 16 components above was probably too many: around 8-12 components would
have been a better choice. As is typical with this sort of problem, the BIC recommends a simpler model.
GMM成分数量的最优选择是能最小化AIC或BIC的值,取决于我们希望应用哪种近似。上图中AIC告诉我们16个成分也许太多了:区间8-12
的成分值会是更好的选择。在通常的情况下,BIC标准会推荐一个更简单的模型。
Notice the important point: this choice of number of components measures how well GMM works as a density estimator,
not how well it works as a clustering algorithm. I'd encourage you to think of GMM primarily as a density estimator, and
use it for clustering only when warranted within simple datasets.
注意很重要的⼀点:这个成分数量的选择只是衡量GMM作为密度估计模型的表现,不是它作为聚类算法的表现。作者更⿎励你将GMM主
要看成是密度估计,仅在简单数据集中才放⼼的将它⽤作聚类算法。
Example: GMM for Generating New Data
例⼦:⽤GMM来⽣成新数据
We just saw a simple example of using GMM as a generative model of data in order to create new samples from the
distribution defined by the input data. Here we will run with this idea and generate new handwritten digits from the
standard digits corpus that we have used before.
我们刚才看到了⼀个⽤GMM作为⽣成模型的简单例⼦,模型能从输⼊数据的分布中产⽣新的样本。下⾯我们将使⽤这个思想来⽣成新的⼿
写数字,使⽤的输⼊训练数据集是我们前⾯很熟悉的⼿写数字数据集。
To start with, let's load the digits data using Scikit-Learn's data tools:
⾸先使⽤Scikit-Learn的数据⼯具载⼊⼿写数字数据集:
In [18]: from sklearn.datasets import load_digits
digits = load_digits()
digits.data.shape
Out[18]: (1797, 64)
Next let's plot the first 100 of these to recall exactly what we're looking at:
然后查看前⾯的100个样本:
In [19]: def plot_digits(data):
    fig, ax = plt.subplots(10, 10, figsize=(8, 8),
                           subplot_kw=dict(xticks=[], yticks=[]))
    fig.subplots_adjust(hspace=0.05, wspace=0.05)
    for i, axi in enumerate(ax.flat):
        im = axi.imshow(data[i].reshape(8, 8), cmap='binary')
        im.set_clim(0, 16)
plot_digits(digits.data)
We have nearly 1,800 digits in 64 dimensions, and we can build a GMM on top of these to generate more. GMMs can
have difficulty converging in such a high dimensional space, so we will start with an invertible dimensionality reduction
algorithm on the data. Here we will use a straightforward PCA, asking it to preserve 99% of the variance in the projected
data:
数据集中有接近1800个数字,每个数字都有64个维度,我们来构建一个GMM模型并且生成更多的手写数字。GMM在如此高维度空间中可
能很难收敛,因此首先我们使用一个可逆的降维算法来降低数据集的维度。下面我们直接使用PCA,要求它在降维后保留99%的可解释方
差:
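The claim that the projection loses almost no information can be quantified. After mapping back with inverse_transform, the squared reconstruction error relative to the total variance equals the fraction of variance carried by the discarded components. The sketch below assumes the same PCA(0.99, whiten=True) setup and checks that this fraction is at most 1%. Whitening is undone by inverse_transform, so it does not affect the reconstruction.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA

X = load_digits().data
pca = PCA(0.99, whiten=True).fit(X)

# project down, then map back to the original 64-dimensional space
X_rec = pca.inverse_transform(pca.transform(X))

# squared reconstruction error relative to the total (centered) variance
rel_err = np.sum((X - X_rec) ** 2) / np.sum((X - X.mean(axis=0)) ** 2)
print(pca.n_components_, rel_err, 1 - pca.explained_variance_ratio_.sum())
```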
In [20]: from sklearn.decomposition import PCA
pca = PCA(0.99, whiten=True)
data = pca.fit_transform(digits.data)
data.shape
Out[20]: (1797, 41)
The result is 41 dimensions, a reduction of nearly 1/3 with almost no information loss. Given this projected data, let's use
the AIC to get a gauge for the number of GMM components we should use:
结果有41个维度,接近降低了1/3的维度,但是基本上没有信息的损失。在这个数据上,我们使用AIC来测算我们在GMM模型中需要使用的
成分数量:
In [21]: n_components = np.arange(50, 210, 10)
models = [GaussianMixture(n, covariance_type='full', random_state=0)
          for n in n_components]
aics = [model.fit(data).aic(data) for model in models]
plt.plot(n_components, aics);
It appears that around 110 components minimizes the AIC; we will use this model. Let's quickly fit this to the data and
confirm that it has converged:
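看起来在150附近AIC具有最小值;让我们使用这个成分值来拟合数据,并查看模型收敛状态:
Translator's note: in the translator's run, 150 components gave the minimum AIC (rather than about 110), so the code below uses 150.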
In [22]: gmm = GaussianMixture(150, covariance_type='full', random_state=0)
gmm.fit(data)
print(gmm.converged_)
True
Now we can draw samples of 100 new points within this 41-dimensional projected space, using the GMM as a generative
model:
现在我们可以使用这个GMM作为生成模型,在这个41维空间中创建100个新的数据点:
In [23]: data_new, _ = gmm.sample(100)
data_new.shape
Out[23]: (100, 41)
Finally, we can use the inverse transform of the PCA object to construct the new digits:
最后,我们使⽤PCA的逆向转换重新构建这100个⼿写数字:
In [24]: digits_new = pca.inverse_transform(data_new)
plot_digits(digits_new)
The results for the most part look like plausible digits from the dataset!
我们看到结果中⼤多数的数字都很符合数据集的情况。
Consider what we've done here: given a sampling of handwritten digits, we have modeled the distribution of that data in
such a way that we can generate brand new samples of digits from the data: these are "handwritten digits" which do not
individually appear in the original dataset, but rather capture the general features of the input data as modeled by the
mixture model. Such a generative model of digits can prove very useful as a component of a Bayesian generative
classifier, as we shall see in the next section.
再次考虑这个过程:给定⼀些⼿写数字的样本,根据数据的分布构建了模型,然后我们使⽤这个模型⽣成全新的数字样本:这些“⼿写数
字”并不是在原始数据集中的,⽽是通过捕捉到输⼊数据的主要特征的混合模型⽣成的。这样的数字⽣成模型作为⻉叶斯⽣成分类器的组成
部分会⾮常有⽤,我们将会在下⼀节看到。
In-Depth: Kernel Density Estimation
深⼊:核密度估计
In the previous section we covered Gaussian mixture models (GMM), which are a kind of hybrid between a clustering
estimator and a density estimator. Recall that a density estimator is an algorithm which takes a D-dimensional dataset
and produces an estimate of the D-dimensional probability distribution which that data is drawn from. The GMM
algorithm accomplishes this by representing the density as a weighted sum of Gaussian distributions. Kernel density
estimation (KDE) is in some senses an algorithm which takes the mixture-of-Gaussians idea to its logical extreme: it uses
a mixture consisting of one Gaussian component per point, resulting in an essentially non-parametric estimator of density.
In this section, we will explore the motivation and uses of KDE.
在上一节中我们介绍了高斯混合模型(GMM),它是一种介于聚类评估器和密度评估器之间的混合模型。回忆一下密度评估器的定义,这是
一种从 D 维数据集中产生一个 D 维概率分布估计的算法。GMM算法使用了加权高斯分布和的方式实现了密度评估器。核密度估计在某种程度
上是一个将高斯混合理念发展到其逻辑极致的算法:它为每个数据点使用一个高斯成分,最终得到一个基本上无参数的密度评估器。
在本节中,我们会讨论核密度估计KDE的原理和应用。
We begin with the standard imports:
导⼊包:
In [1]: %matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import numpy as np
Motivating KDE: Histograms
初探KDE:直⽅图
As already discussed, a density estimator is an algorithm which seeks to model the probability distribution that generated
a dataset. For one dimensional data, you are probably already familiar with one simple density estimator: the histogram.
A histogram divides the data into discrete bins, counts the number of points that fall in each bin, and then visualizes the
results in an intuitive manner.
前面已经讨论过,密度评估器是一种试图对生成数据集的概率分布进行建模的算法。对于一维数据而言,你应该已经熟悉其中一种简单的
密度评估器:直方图。直方图将数据分成离散的桶,计算每个桶中数据点的数量,然后将结果可视化成一张非常直观的图表。
For example, let's create some data that is drawn from two normal distributions:
下⾯我们构建⼀些数据形成两个正态分布:
In [2]: def make_data(N, f=0.3, rseed=1):
    rand = np.random.RandomState(rseed)
    x = rand.randn(N)
    x[int(f * N):] += 5
    return x

x = make_data(1000)
We have previously seen that the standard count-based histogram can be created with the plt.hist() function. By
specifying the density parameter of the histogram (called normed in older versions of Matplotlib), we end up with a
normalized histogram where the height of the bins does not reflect counts, but instead reflects probability density:
前⾯我们已经看到标准的直⽅图可以使⽤ plt.hist() 函数绘制。通过设置 density 参数,我们可以将直⽅图标准化,这时图像的⾼度
不再代表数据点的数量,⽽是概率密度:
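Translator's note: newer versions of Matplotlib no longer accept the normed parameter; the parameter name in the text and code has been changed to density.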
In [3]: hist = plt.hist(x, bins=30, density=True)
Notice that for equal binning, this normalization simply changes the scale on the y-axis, leaving the relative heights
essentially the same as in a histogram built from counts. This normalization is chosen so that the total area under the
histogram is equal to 1, as we can confirm by looking at the output of the histogram function:
注意对上图来说,标准化只是修改了y轴的度量,但是每个桶相对⾼度与使⽤简单求和构建的直⽅图是⼀致的。标准化能够使得直⽅图的全
部⾯积加起来等于1,可以通过直⽅图函数的输出结果来进⾏验证:
In [4]: density, bins, patches = hist
widths = bins[1:] - bins[:-1]
(density * widths).sum()
Out[4]: 1.0
One of the issues with using a histogram as a density estimator is that the choice of bin size and location can lead to
representations that have qualitatively different features. For example, if we look at a version of this data with only 20
points, the choice of how to draw the bins can lead to an entirely different interpretation of the data! Consider this
example:
使⽤直⽅图作为密度评估器的⼀个问题是桶⼤⼩和位置的选择会导致展现出不同的数据特征。例如我们仅仅使⽤20个数据点的情况下,不
同的选择会得到完全不同的数据解释。如下例:
In [5]: x = make_data(20)
bins = np.linspace(-5, 10, 10)
In [6]: fig, ax = plt.subplots(1, 2, figsize=(12, 4),
                       sharex=True, sharey=True,
                       subplot_kw={'xlim':(-4, 9),
                                   'ylim':(-0.02, 0.3)})
fig.subplots_adjust(wspace=0.05)
for i, offset in enumerate([0.0, 0.6]):
    ax[i].hist(x, bins=bins + offset, density=True)
    ax[i].plot(x, np.full_like(x, -0.01), '|k',
               markeredgewidth=1)
On the left, the histogram makes clear that this is a bimodal distribution. On the right, we see a unimodal distribution with
a long tail. Without seeing the preceding code, you would probably not guess that these two histograms were built from
the same data: with that in mind, how can you trust the intuition that histograms confer? And how might we improve on
this?
左边的直⽅图很明显是⼀个双峰分布。右边的直⽅图却是⼀个单峰分布。如果不是看到了前⾯的代码,我们可能会猜测这两个直⽅图是从
不同的数据集获得的:在这种情况下,如何能信任直⽅图给我们关于数据分布的直觉呢?该如何改进这点呢?
Stepping back, we can think of a histogram as a stack of blocks, where we stack one block within each bin on top of each
point in the dataset. Let's view this directly:
再回头深⼊考虑⼀下,我们可以将直⽅图想象成⽅块组成的堆,将数据集中的每个数据点都作为⼀个⽅块放置到其从属的桶的最上⽅。我
们来看看:
In [7]: fig, ax = plt.subplots()
bins = np.arange(-3, 8)
ax.plot(x, np.full_like(x, -0.1), '|k',
        markeredgewidth=1)
for count, edge in zip(*np.histogram(x, bins)):
    for i in range(count):
        ax.add_patch(plt.Rectangle((edge, i), 1, 1,
                                   alpha=0.5))
ax.set_xlim(-4, 8)
ax.set_ylim(-0.2, 8);
The problem with our two binnings stems from the fact that the height of the block stack often reflects not the actual
density of points nearby, but coincidences of how the bins align with the data points. This mis-alignment between
points and their blocks is a potential cause of the poor histogram results seen here. But what if, instead of stacking the
blocks aligned with the bins, we were to stack the blocks aligned with the points they represent? If we do this, the blocks
won't be aligned, but we can add their contributions at each location along the x-axis to find the result. Let's try this:
刚才看到那两个直⽅图的问题实质在于,⽅块组成的堆⾼度通常反映的不是实际的附近数据点密度,⽽是取决于桶与数据点对⻬的选择⽅
式,这具有⼀定的偶然性。不合适的选择就是我们前⾯看到不正确的直⽅图结果的原因。但是如果我们不是将⽅块叠放到桶上,⽽是将⽅
块叠放到它们所代表的数据点上会怎么样?这样做的话,这些⽅块不会对⻬,我们可以将每个数据点在x轴的每个位置上的贡献累加起来得
到结果。例如:
In [8]: x_d = np.linspace(-4, 8, 2000)
density = sum((abs(xi - x_d) < 0.5) for xi in x)
plt.fill_between(x_d, density, alpha=0.5)
plt.plot(x, np.full_like(x, -0.1), '|k', markeredgewidth=1)
plt.axis([-4, 8, -0.2, 8]);
The result looks a bit messy, but is a much more robust reflection of the actual data characteristics than is the standard
histogram. Still, the rough edges are not aesthetically pleasing, nor are they reflective of any true properties of the data.
In order to smooth them out, we might decide to replace the blocks at each location with a smooth function, like a
Gaussian. Let's use a standard normal curve at each point instead of a block:
虽然结果看起来有点乱,但是它能比标准直方图更加健壮地反映数据的特征。然而图中粗糙的边缘很不美观,且它们也无法反映数据的真
实属性。我们可以考虑使用光滑的函数,如高斯函数,来平滑这个图形。下面我们在每个数据点上使用标准正态曲线来取代叠放的方
块:
In [9]: from scipy.stats import norm
x_d = np.linspace(-4, 8, 1000)
density = sum(norm(xi).pdf(x_d) for xi in x)
plt.fill_between(x_d, density, alpha=0.5)
plt.plot(x, np.full_like(x, -0.1), '|k', markeredgewidth=1)
plt.axis([-4, 8, -0.2, 5]);
This smoothed-out plot, with a Gaussian distribution contributed at the location of each input point, gives a much more
accurate idea of the shape of the data distribution, and one which has much less variance (i.e., changes much less in
response to differences in sampling).
平滑后的图像,在每个输⼊点上都是⾼斯分布,能够提供对于数据分布的更加精确的形状,⽽且具有更少的差异(因为取样不同产⽣的差
异⼩了许多)。
These last two plots are examples of kernel density estimation in one dimension: the first uses a so-called "tophat" kernel
and the second uses a Gaussian kernel. We'll now look at kernel density estimation in more detail.
后⾯这两张图就是核密度估计在⼀维数据上的例⼦:第⼀幅图使⽤的是“⾼帽”核,第⼆幅图使⽤的是⾼斯核。下⾯我们详细讨论核密度估
计。
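The two plots above can be reproduced with an explicit formula. A kernel density estimate is the average of one kernel per data point, scaled by the bandwidth h so that the total area is one: p(x) = 1/(N*h) * sum_i K((x - x_i) / h). The sketch below is a hand-rolled NumPy version. The helper kde_1d is hypothetical, and the tophat convention (a box of half-width h) is one common choice. It checks numerically that both kernels integrate to one. With bandwidth=0.5 the tophat case matches the |x_i - x| < 0.5 blocks used earlier, apart from the normalization.

```python
import numpy as np


def kde_1d(x_eval, data, bandwidth=1.0, kernel='gaussian'):
    """Hand-rolled 1D KDE: average of one kernel per data point."""
    u = (x_eval[:, None] - data[None, :]) / bandwidth  # shape (M, N)
    if kernel == 'gaussian':
        k = np.exp(-0.5 * u ** 2) / np.sqrt(2 * np.pi)
    elif kernel == 'tophat':
        k = 0.5 * (np.abs(u) < 1)  # box of half-width `bandwidth`, area 1
    else:
        raise ValueError(f"unknown kernel: {kernel}")
    # dividing by N * bandwidth makes the total area equal to 1
    return k.sum(axis=1) / (len(data) * bandwidth)


rng = np.random.RandomState(1)
data = np.concatenate([rng.randn(6), rng.randn(14) + 5])
grid = np.linspace(-10, 15, 20001)
dx = grid[1] - grid[0]

areas = {}
for kernel in ['tophat', 'gaussian']:
    areas[kernel] = kde_1d(grid, data, bandwidth=0.5, kernel=kernel).sum() * dx
print(areas)
```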
Kernel Density Estimation in Practice
实践中使⽤核密度估计
The free parameters of kernel density estimation are the kernel, which specifies the shape of the distribution placed at
each point, and the kernel bandwidth, which controls the size of the kernel at each point. In practice, there are many
kernels you might use for a kernel density estimation: in particular, the Scikit-Learn KDE implementation supports one of
six kernels, which you can read about in Scikit-Learn's Density Estimation documentation.
核密度估计中的⾃由参数是核,它设定了分布在每个点的形状以及控制着每个点上核的⼤⼩(被称为核带宽)的参数。实践中有许多可⽤
的核密度估计:具体来说,Scikit-Learn的KDE实现了其中的6种,读者可以在Scikit-Learn在线⽂档密度估计中查看。
While there are several versions of kernel density estimation implemented in Python (notably in the SciPy and
StatsModels packages), I prefer to use Scikit-Learn's version because of its efficiency and flexibility. It is implemented in
the sklearn.neighbors.KernelDensity estimator, which handles KDE in multiple dimensions with one of six
kernels and one of a couple dozen distance metrics. Because KDE can be fairly computationally intensive, the Scikit-Learn estimator uses a tree-based algorithm under the hood and can trade off computation time for accuracy using the
atol (absolute tolerance) and rtol (relative tolerance) parameters. The kernel bandwidth, which is a free parameter,
can be determined using Scikit-Learn's standard cross validation tools as we will soon see.
虽然Python当中有⼀些核密度估计的实现(主要是在SciPy和StatsModels包中),作者还是建议使⽤Scikit-Learn的版本,因为它具有⾼效
和灵活的特性。这些评估器被实现在 sklearn.neighbors.KernelDensity 当中,它们能使⽤6种核类型以及数⼗种距离度量计算⽅
法在多维数据中实现KDE。因为KDE⽅法较为计算密集,Scikit-Learn的评估器在底层使⽤了树形算法,并且能够使⽤ atol (绝对容差)
和 rtol (相对容差)来平衡计算时间与精确度。其中的⾃由参数核带宽可以使⽤标准的交叉验证⼯具决定,我们⻢上就会看到。
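The atol/rtol trade-off can be seen by comparing an exact evaluation (the defaults atol=0, rtol=0) with a looser relative tolerance. In the sketch below the data, the bandwidth and rtol=1e-3 are arbitrary illustrative choices. The approximate densities are expected to stay within roughly that relative tolerance of the exact values. The speed benefit only becomes noticeable on much larger datasets.

```python
import numpy as np
from sklearn.neighbors import KernelDensity

rng = np.random.RandomState(0)
X = rng.randn(2000, 1)                      # toy 1D data
grid = np.linspace(-4, 4, 200)[:, None]

# Defaults atol=0, rtol=0 request an exact evaluation
exact = KernelDensity(bandwidth=0.3).fit(X)
# A looser relative tolerance lets the tree prune more work
approx = KernelDensity(bandwidth=0.3, rtol=1e-3).fit(X)

p_exact = np.exp(exact.score_samples(grid))
p_approx = np.exp(approx.score_samples(grid))
print("max relative difference:", np.max(np.abs(p_approx - p_exact) / p_exact))
```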
Let's first show a simple example of replicating the above plot using the Scikit-Learn KernelDensity estimator:
下⾯使⽤Scikit-Learn的 KernelDensity 评估器重复⼀下上⾯的图表,作为⼀个简单的例⼦:
In [10]: from sklearn.neighbors import KernelDensity

# instantiate and fit the KDE model
kde = KernelDensity(bandwidth=1.0, kernel='gaussian')
kde.fit(x[:, None])

# score_samples returns the log of the probability density
logprob = kde.score_samples(x_d[:, None])

plt.fill_between(x_d, np.exp(logprob), alpha=0.5)
plt.plot(x, np.full_like(x, -0.01), '|k', markeredgewidth=1)
plt.ylim(-0.02, 0.22);
The result here is normalized such that the area under the curve is equal to 1.
上⾯的结果已经标准化了,因此曲线下⽅的⾯积为1。
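The normalization holds for the other kernels Scikit-Learn offers, not only the Gaussian. The sketch below numerically integrates exp(score_samples) for three of the six kernels. It assumes small two-cluster 1D data and a fine evaluation grid, both arbitrary choices. Outside a compact kernel's support, score_samples returns -inf, which exponentiates to zero.

```python
import numpy as np
from sklearn.neighbors import KernelDensity

rng = np.random.RandomState(1)
x = np.concatenate([rng.randn(6), rng.randn(14) + 5])
grid = np.linspace(-10, 15, 20001)
dx = grid[1] - grid[0]

areas = {}
with np.errstate(divide='ignore'):  # log(0) outside compact kernels' support
    for kernel in ['gaussian', 'tophat', 'epanechnikov']:
        kde = KernelDensity(bandwidth=1.0, kernel=kernel).fit(x[:, None])
        density = np.exp(kde.score_samples(grid[:, None]))
        areas[kernel] = density.sum() * dx
print(areas)
```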
Selecting the bandwidth via cross-validation
通过交叉验证选择带宽
The choice of bandwidth within KDE is extremely important to finding a suitable density estimate, and is the knob that
controls the bias–variance trade-off in the estimate of density: too narrow a bandwidth leads to a high-variance estimate
(i.e., over-fitting), where the presence or absence of a single point makes a large difference. Too wide a bandwidth leads
to a high-bias estimate (i.e., under-fitting) where the structure in the data is washed out by the wide kernel.
KDE中带宽的选择对于寻找合适的密度估计是至关重要的,它也是控制密度估计中偏差-方差权衡的旋钮:太窄的带宽会导致高方差估计
(也就是过拟合),也就是一个数据点的存在或缺失会导致巨大的差异。太宽泛的带宽会导致高偏差估计(也就是欠拟合),整个数据的
结构被过宽的核给抹平了。
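A quick way to see the bias-variance trade-off is to count the local maxima (modes) of the estimated density as the bandwidth changes. In the sketch below, count_modes is a hypothetical helper, and the data and bandwidths are illustrative. A very narrow kernel produces a spike for nearly every point, while a very wide one merges everything into a single bump.

```python
import numpy as np
from sklearn.neighbors import KernelDensity

rng = np.random.RandomState(1)
x = np.concatenate([rng.randn(6), rng.randn(14) + 5])
grid = np.linspace(-4, 9, 2000)


def count_modes(bandwidth):
    """Count local maxima of a Gaussian KDE evaluated on the grid."""
    kde = KernelDensity(bandwidth=bandwidth).fit(x[:, None])
    p = np.exp(kde.score_samples(grid[:, None]))
    return int(np.sum((p[1:-1] > p[:-2]) & (p[1:-1] > p[2:])))


for bw in [0.05, 1.0, 5.0]:
    print(bw, count_modes(bw))
```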
There is a long history in statistics of methods to quickly estimate the best bandwidth based on rather stringent
assumptions about the data: if you look up the KDE implementations in the SciPy and StatsModels packages, for
example, you will see implementations based on some of these rules.
在统计学中,基于数据相当严格的假设来估计最佳带宽有着很⻓的历史:如果你查看SciPy和StatsModels包中的KDE实现,你可以看到其
中⼀些规则的实现。
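As an example of such a rule, scipy.stats.gaussian_kde defaults to Scott's rule, which scales the data covariance by the squared factor n**(-1/(d + 4)). The sketch below uses 1D toy data (an arbitrary choice) and checks that the effective kernel standard deviation equals the sample standard deviation times that factor. Silverman's rule, the other built-in option (bw_method='silverman'), uses the factor (n*(d + 2)/4)**(-1/(d + 4)) instead.

```python
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.RandomState(1)
x = np.concatenate([rng.randn(6), rng.randn(14) + 5])

kde = gaussian_kde(x)                  # bw_method defaults to Scott's rule
n, d = len(x), 1
scott = n ** (-1.0 / (d + 4))
h = x.std(ddof=1) * scott              # effective Gaussian kernel std dev in 1D
print(kde.factor, scott, np.sqrt(kde.covariance[0, 0]), h)
```

Rules like this assume the data are roughly Gaussian. For the clearly bimodal data used here, they can oversmooth, which motivates the cross-validation approach below.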
In machine learning contexts, we've seen that such hyperparameter tuning often is done empirically via a cross-validation
approach. With this in mind, the KernelDensity estimator in Scikit-Learn is designed such that it can be used directly
within the Scikit-Learn's standard grid search tools. Here we will use GridSearchCV to optimize the bandwidth for the
preceding dataset. Because we are looking at such a small dataset, we will use leave-one-out cross-validation, which
minimizes the reduction in training set size for each cross-validation trial:
在机器学习领域,我们已经知道这样的超参数调整通常可以通过交叉验证⽅法来实现。因此Scikit-Learn中的 KernelDensity 评估器被
设计成可以直接使⽤Scikit-Learn的标准⽹格搜索⼯具。这⾥我们将使⽤ GridSearchCV 来对前⾯的数据集的带宽进⾏优化。因为这是⼀
个非常小的数据集,我们会使用留一交叉验证方法,这能在每次交叉验证测试中尽量保证训练集的最大样本量:
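Translator's note: newer versions of Scikit-Learn have moved GridSearchCV and LeaveOneOut into the sklearn.model_selection package, and LeaveOneOut no longer takes an argument. The code below has been updated accordingly.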
In [11]: from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import LeaveOneOut

bandwidths = 10 ** np.linspace(-1, 1, 100)
grid = GridSearchCV(KernelDensity(kernel='gaussian'),
                    {'bandwidth': bandwidths},
                    cv=LeaveOneOut())
grid.fit(x[:, None]);
Now we can find the choice of bandwidth which maximizes the score (which in this case defaults to the log-likelihood):
然后就可以得到最大分值的带宽了(本例中分值默认为对数似然):
In [12]: grid.best_params_
Out[12]: {'bandwidth': 1.1233240329780276}
The optimal bandwidth happens to be very close to what we used in the example plot earlier, where the bandwidth was
1.0 (i.e., the default width of scipy.stats.norm ).
带宽的最优值正好⾮常接近我们在前⾯例⼦中使⽤的1.0(也是 scipy.stats.norm 的默认宽度)。
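The sketch below performs leave-one-out cross-validation by hand, to show what GridSearchCV computes here. For each candidate bandwidth, it fits a KDE to all points but one, records the log-density of the held-out point, and averages. It assumes a 10-value bandwidth grid and small two-cluster data, both arbitrary choices that keep it fast. KernelDensity.score returns the total log-likelihood of the data it is given. The manual averages should therefore match GridSearchCV's mean_test_score.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV, LeaveOneOut
from sklearn.neighbors import KernelDensity

rng = np.random.RandomState(1)
x = np.concatenate([rng.randn(6), rng.randn(14) + 5])
bandwidths = 10 ** np.linspace(-1, 1, 10)


def loo_score(bw):
    """Mean log-density of each point under a KDE fit to all other points."""
    scores = []
    for i in range(len(x)):
        train = np.delete(x, i)
        kde = KernelDensity(bandwidth=bw).fit(train[:, None])
        scores.append(kde.score_samples(x[i:i + 1, None])[0])
    return np.mean(scores)


manual = [loo_score(bw) for bw in bandwidths]
best_manual = bandwidths[np.argmax(manual)]

grid = GridSearchCV(KernelDensity(), {'bandwidth': bandwidths}, cv=LeaveOneOut())
grid.fit(x[:, None])
print(best_manual, grid.best_params_['bandwidth'])
```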
Example: KDE on a Sphere
例⼦:球⾯上的KDE
Perhaps the most common use of KDE is in graphically representing distributions of points. For example, in the Seaborn
visualization library (see Visualization With Seaborn), KDE is built in and automatically used to help visualize points in
one and two dimensions.
KDE最常见的应用可能是数据点分布的图像表示。例如在Seaborn可视化库(参见使用Seaborn可视化)中,KDE是在一维和二维空间中
的内建的自动化可视化方法。
Here we will look at a slightly more sophisticated use of KDE for visualization of distributions. We will make use of some
geographic data that can be loaded with Scikit-Learn: the geographic distributions of recorded observations of two South
American mammals, Bradypus variegatus (the Brown-throated Sloth) and Microryzomys minutus (the Forest Small Rice
Rat).
这⾥我们将要讨论⼀个稍微复杂⼀些的KDE进⾏数据分布可视化的例⼦:观察记录到两种南美哺乳动物的地理分布情况,棕喉树懒和森林
⼩稻⿏。
With Scikit-Learn, we can fetch this data as follows:
使⽤Scikit-Learn如下获取数据:
In [13]: from sklearn.datasets import fetch_species_distributions

data = fetch_species_distributions()

# Get arrays of locations (latitude, longitude) and species IDs
latlon = np.vstack([data.train['dd lat'],
                    data.train['dd long']]).T
species = np.array([d.decode('ascii').startswith('micro')
                    for d in data.train['species']], dtype='int')
With this data loaded, we can use the Basemap toolkit (mentioned previously in Geographic Data with Basemap) to plot
the observed locations of these two species on the map of South America.
当数据载⼊后,我们可以使⽤Basemap⼯具集(之前在使⽤Basemap创建地理位置图表中介绍过)来绘制这两个物种在南美洲地图上观测
的位置。
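Translator's note: the Scikit-Learn version used by the translator emits a warning that the species_distributions module is deprecated. This has been fixed on the master branch but not yet released. The warning is kept below; it does not affect the subsequent code.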
In [14]: from mpl_toolkits.basemap import Basemap
from sklearn.datasets.species_distributions import construct_grids

xgrid, ygrid = construct_grids(data)

# plot coastlines with Basemap
m = Basemap(projection='cyl', resolution='c',
            llcrnrlat=ygrid.min(), urcrnrlat=ygrid.max(),
            llcrnrlon=xgrid.min(), urcrnrlon=xgrid.max())
m.drawmapboundary(fill_color='#DDEEFF')
m.fillcontinents(color='#FFEEDD')
m.drawcoastlines(color='gray', zorder=2)
m.drawcountries(color='gray', zorder=2)

# plot locations
m.scatter(latlon[:, 1], latlon[:, 0], zorder=3,
          c=species, cmap='rainbow', latlon=True);
/home/wangy/anaconda3/lib/python3.7/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.datasets.species_distributions module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.datasets. Anything that cannot be imported from sklearn.datasets is now part of the private API.
  warnings.warn(message, FutureWarning)
Unfortunately, this doesn't give a very good idea of the density of the species, because points in the species range may
overlap one another. You may not realize it by looking at this plot, but there are over 1,600 points shown here!
不过,上图并没有给出这两种动物的分布密度估计,因为这些点的范围互相重叠了。上图中有很多分布点,实际上有超过1600个数据点绘
制在图中。
Let's use kernel density estimation to show this distribution in a more interpretable way: as a smooth indication of density
on the map. Because the coordinate system here lies on a spherical surface rather than a flat plane, we will use the
haversine distance metric, which will correctly represent distances on a curved surface.
让我们使⽤核密度估计将这个分布展⽰成更加有含义的图表:在地图上显⽰平滑的密度分布情况。因为实际上使⽤的是球⾯坐标系统⽽不
是平⾯坐标系,所以距离度量采取了 haversine ,这是⼀个能正确表达曲⾯距离的⽅法。
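The haversine distance between two points given as [latitude, longitude] in radians is 2 * arcsin(sqrt(sin(dlat/2)**2 + cos(lat1) * cos(lat2) * sin(dlon/2)**2)). This is the angle between the points on a unit sphere. The sketch below implements the formula directly, using two arbitrary illustrative points. It compares the result with sklearn.metrics.pairwise.haversine_distances, which uses the same [lat, lon] radian convention. That convention is also why the code above stacks latitude before longitude and converts with np.radians. Multiplying by a mean Earth radius of about 6371 km converts radians to kilometres, so the bandwidth of 0.03 used below corresponds to roughly 190 km.

```python
import numpy as np
from sklearn.metrics.pairwise import haversine_distances


def haversine(p1, p2):
    """Great-circle angle (radians) between two [lat, lon] points in radians."""
    lat1, lon1 = p1
    lat2, lon2 = p2
    a = (np.sin((lat2 - lat1) / 2) ** 2
         + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2)
    return 2 * np.arcsin(np.sqrt(a))


# Two arbitrary illustrative points, [latitude, longitude] in degrees
p1 = np.radians([-0.2, -78.5])
p2 = np.radians([-12.0, -77.0])

d = haversine(p1, p2)
d_sklearn = haversine_distances([p1, p2])[0, 1]
print(d, d_sklearn)
print("approx. km:", d * 6371)  # assumes a mean Earth radius of ~6371 km
print("bandwidth 0.03 rad in km:", 0.03 * 6371)
```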
There is a bit of boilerplate code here (one of the disadvantages of the Basemap toolkit) but the meaning of each code
block should be clear:
下⾯的代码有点冗⻓(Basemap⼯具集的缺点之⼀),但是每个代码块的含义还是很清晰的:
In [15]: # Set up the data grid for the contour plot
X, Y = np.meshgrid(xgrid[::5], ygrid[::5][::-1])
land_reference = data.coverages[6][::5, ::5]
land_mask = (land_reference > -9999).ravel()
xy = np.vstack([Y.ravel(), X.ravel()]).T
xy = np.radians(xy[land_mask])

# Create two side-by-side plots
fig, ax = plt.subplots(1, 2)
fig.subplots_adjust(left=0.05, right=0.95, wspace=0.05)
species_names = ['Bradypus Variegatus', 'Microryzomys Minutus']
cmaps = ['Purples', 'Reds']

for i, axi in enumerate(ax):
    axi.set_title(species_names[i])

    # Plot coastlines and country borders with Basemap
    m = Basemap(projection='cyl', llcrnrlat=Y.min(),
                urcrnrlat=Y.max(), llcrnrlon=X.min(),
                urcrnrlon=X.max(), resolution='c', ax=axi)
    m.drawmapboundary(fill_color='#DDEEFF')
    m.drawcoastlines()
    m.drawcountries()

    # Construct a spherical kernel density estimate of the distribution
    kde = KernelDensity(bandwidth=0.03, metric='haversine')
    kde.fit(np.radians(latlon[species == i]))

    # Evaluate only on the land: -9999 indicates ocean
    Z = np.full(land_mask.shape[0], -9999.0)
    Z[land_mask] = np.exp(kde.score_samples(xy))
    Z = Z.reshape(X.shape)

    # Plot contours of the density
    levels = np.linspace(0, Z.max(), 25)
    axi.contourf(X, Y, Z, levels=levels, cmap=cmaps[i])
Compared to the simple scatter plot we initially used, this visualization paints a much clearer picture of the geographical
distribution of observations of these two species.
对⽐前⾯我们绘制的简单散点图,上⾯两个图表很清晰的展⽰了两种动物的地理位置分布情况。
Example: Not-So-Naive Bayes
例⼦:⾮朴素⻉叶斯
This example looks at Bayesian generative classification with KDE, and demonstrates how to use the Scikit-Learn
architecture to create a custom estimator.
下⾯这个例⼦我们来看下使⽤KDE创建⻉叶斯⽣成分类,并且展⽰如何使⽤Scikit-Learn创建⾃定义的评估器。
In In Depth: Naive Bayes Classification, we took a look at naive Bayesian classification, in which we created a simple
generative model for each class, and used these models to build a fast classifier. For Gaussian naive Bayes, the
generative model is a simple axis-aligned Gaussian. With a density estimation algorithm like KDE, we can remove the
"naive" element and perform the same classification with a more sophisticated generative model for each class. It's still
Bayesian classification, but it's no longer naive.
在深⼊:朴素⻉叶斯分类中,我们学习了朴素⻉叶斯分类,⾥⾯构建了每个类别的简单⽣成模型并且使⽤这些模型来构建⼀个快速分类
器。对于⾼斯朴素⻉叶斯来说,⽣成模型就是简单的沿着坐标轴的⾼斯函数。使⽤密度估计算法如KDE,我们可以去除其中的“朴素”成
分,然后对每个类别使⽤更加复杂的⽣成模型进⾏相同的分类⼯作。这仍然是⻉叶斯分类,只是不再朴素。
The general approach for generative classification is this:
1. Split the training data by label.
2. For each set, fit a KDE to obtain a generative model of the data. This allows you to compute, for any observation x
and label y, a likelihood P(x | y).
3. From the number of examples of each class in the training set, compute the class prior, P(y).
4. For an unknown point x, the posterior probability for each class is P(y | x) ∝ P(x | y)P(y). The class which
maximizes this posterior is the label assigned to the point.
⽣成分类的通⽤⽅法如下:
1. 将训练数据依据标签划分成不同类别。
2. 对每个类别,使用KDE拟合数据获得一个生成模型。这允许你对于任何观察 x 和标签 y 计算出似然 P(x | y)。
3. 对训练集中的每个类别,从样本数量计算得到类别先验概率 P(y)。
4. 对一个未知点 x,每个类别的后验概率是 P(y | x) ∝ P(x | y)P(y)。哪个类别具有最大的后验概率值,就将这个点设置为该类别标
签。
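Step 4 is easiest to carry out in log space, since the likelihoods from score_samples are log-densities. The sketch below uses made-up log-likelihoods and class counts for one point and three classes; all numbers are hypothetical. It normalizes the posterior with scipy.special.logsumexp.

```python
import numpy as np
from scipy.special import logsumexp

# Hypothetical log-likelihoods log P(x | y) of one point under three class models
log_lik = np.array([-12.0, -10.5, -15.0])
# Hypothetical training-set class counts -> log priors log P(y)
counts = np.array([50, 30, 20])
log_prior = np.log(counts / counts.sum())

# log P(y | x) up to a constant, then normalize in log space
log_post = log_lik + log_prior
posterior = np.exp(log_post - logsumexp(log_post))
print(posterior.round(3), "-> predicted class", posterior.argmax())
```

Class 0 has the largest prior, but class 1 wins because its likelihood is 1.5 log-units higher, which outweighs the prior ratio.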
The algorithm is straightforward and intuitive to understand; the more difficult piece is couching it within the Scikit-Learn
framework in order to make use of the grid search and cross-validation architecture.
上述算法很直观和易于理解;更困难的部分是将它实现在Scikit-Learn框架当中,这样就能使⽤⽹格搜索和交叉验证⼯具。
This is the code that implements the algorithm within the Scikit-Learn framework; we will step through it following the
code block:
下⾯是Scikit-Learn框架中实现这个算法的代码;我们过⼀遍这些代码⽚段:
In [16]: from sklearn.base import BaseEstimator, ClassifierMixin


class KDEClassifier(BaseEstimator, ClassifierMixin):
    """Bayesian generative classification based on KDE

    Parameters
    ----------
    bandwidth : float
        the kernel bandwidth within each class
    kernel : str
        the kernel name, passed to KernelDensity
    """
    def __init__(self, bandwidth=1.0, kernel='gaussian'):
        self.bandwidth = bandwidth
        self.kernel = kernel

    def fit(self, X, y):
        self.classes_ = np.sort(np.unique(y))
        training_sets = [X[y == yi] for yi in self.classes_]
        self.models_ = [KernelDensity(bandwidth=self.bandwidth,
                                      kernel=self.kernel).fit(Xi)
                        for Xi in training_sets]
        self.logpriors_ = [np.log(Xi.shape[0] / X.shape[0])
                           for Xi in training_sets]
        return self

    def predict_proba(self, X):
        logprobs = np.array([model.score_samples(X)
                             for model in self.models_]).T
        result = np.exp(logprobs + self.logpriors_)
        return result / result.sum(1, keepdims=True)

    def predict(self, X):
        return self.classes_[np.argmax(self.predict_proba(X), 1)]
The anatomy of a custom estimator
⾃定义评估器代码剖析
Let's step through this code and discuss the essential features:
让我们⼀步⼀步的分析上⾯的代码并讨论其中最关键的特性:
from sklearn.base import BaseEstimator, ClassifierMixin

class KDEClassifier(BaseEstimator, ClassifierMixin):
    """Bayesian generative classification based on KDE

    Parameters
    ----------
    bandwidth : float
        the kernel bandwidth within each class
    kernel : str
        the kernel name, passed to KernelDensity
    """
Each estimator in Scikit-Learn is a class, and it is most convenient for this class to inherit from the BaseEstimator
class as well as the appropriate mixin, which provides standard functionality. For example, among other things, here the
BaseEstimator contains the logic necessary to clone/copy an estimator for use in a cross-validation procedure, and
ClassifierMixin defines a default score() method used by such routines. We also provide a doc string, which
will be captured by IPython's help functionality (see Help and Documentation in IPython).
Scikit-Learn中的每个评估器都是一个类(译者注:Python类),对于评估器类来说最方便的就是继承 BaseEstimator 类以及相应的
混合类(mixin),它们能提供标准的功能。例如这里 BaseEstimator 包含着在交叉验证过程中克隆/复制评估器所需的代码逻辑,
ClassifierMixin 定义了这类过程使用的默认 score() 方法。我们还提供了文档字符串,可以被IPython的帮助功能捕获到(参见
IPython的帮助和文档)。
Next comes the class initialization method:
下⾯是类实例初始化⽅法:
    def __init__(self, bandwidth=1.0, kernel='gaussian'):
        self.bandwidth = bandwidth
        self.kernel = kernel
This is the actual code that is executed when the object is instantiated with KDEClassifier() . In Scikit-Learn, it is
important that initialization contains no operations other than assigning the passed values by name to self . This is due
to the logic contained in BaseEstimator required for cloning and modifying estimators for cross-validation, grid
search, and other functions. Similarly, all arguments to __init__ should be explicit: i.e. *args or **kwargs
should be avoided, as they will not be correctly handled within cross-validation routines.
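The sketch below uses two hypothetical estimators to show why this rule matters. Cross-validation and grid search make fresh copies of an estimator with clone(). That function rebuilds the estimator from get_params() and then checks that the constructor stored each parameter unchanged. In the Scikit-Learn versions we are aware of, an __init__ that transforms its arguments makes clone() raise a RuntimeError.

```python
from sklearn.base import BaseEstimator, clone


class GoodEstimator(BaseEstimator):
    def __init__(self, bandwidth=1.0):
        self.bandwidth = bandwidth        # store the argument exactly as given


class BadEstimator(BaseEstimator):
    def __init__(self, bandwidth=1.0):
        self.bandwidth = 2 * bandwidth    # "work" in __init__: the value is altered


good_copy = clone(GoodEstimator(bandwidth=0.5))
print(good_copy.get_params())             # {'bandwidth': 0.5}

try:
    clone(BadEstimator(bandwidth=0.5))
    bad_clone_failed = False
except RuntimeError as err:               # clone() notices the stored parameter changed
    bad_clone_failed = True
    print("clone failed:", err)
```

Any derived quantity, such as a transformed bandwidth, therefore belongs in fit(), stored under a trailing-underscore name.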
Next comes the fit() method, where we handle training data:
    def fit(self, X, y):
        self.classes_ = np.sort(np.unique(y))
        training_sets = [X[y == yi] for yi in self.classes_]
        self.models_ = [KernelDensity(bandwidth=self.bandwidth,
                                      kernel=self.kernel).fit(Xi)
                        for Xi in training_sets]
        self.logpriors_ = [np.log(Xi.shape[0] / X.shape[0])
                           for Xi in training_sets]
        return self
Here we find the unique classes in the training data, train a KernelDensity model for each class, and compute the
class priors based on the number of input samples. Finally, fit() should always return self so that we can chain
commands. For example:
label = model.fit(X, y).predict(X)
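The fit() logic can be traced step by step on a small synthetic dataset. This sketch repeats the same three steps outside the class: split by label, fit one KernelDensity per class, and turn class counts into log priors. The two Gaussian blobs and the bandwidth of 0.5 are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.neighbors import KernelDensity

rng = np.random.RandomState(0)
# hypothetical toy data: 30 points of class 0 and 10 points of class 1
X = np.vstack([rng.normal(0, 1, size=(30, 2)),
               rng.normal(3, 1, size=(10, 2))])
y = np.array([0] * 30 + [1] * 10)

classes = np.sort(np.unique(y))                   # array([0, 1])
training_sets = [X[y == c] for c in classes]      # one subset per class
models = [KernelDensity(bandwidth=0.5).fit(Xc) for Xc in training_sets]
logpriors = [np.log(Xc.shape[0] / X.shape[0]) for Xc in training_sets]
print(np.exp(logpriors))                          # priors, approximately [0.75 0.25]

# log p(x | class) for a new point, one value per class model
x_new = np.array([[3.0, 3.0]])
loglik = [m.score_samples(x_new)[0] for m in models]
```

Because 30 of the 40 points are in class 0, the priors come out as 0.75 and 0.25. A point at the centre of class 1 gets the higher log-likelihood from the class-1 model.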
Notice that each persistent result of the fit is stored with a trailing underscore (e.g., self.logpriors_). This is a
convention used in Scikit-Learn so that you can quickly scan the members of an estimator (using IPython's tab
completion) and see exactly which members are fit to training data.
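The trailing-underscore convention is more than cosmetic. Scikit-Learn's sklearn.utils.validation.check_is_fitted, when called without an attributes argument, treats an estimator as fitted if it has any attribute whose name ends in an underscore. This describes the versions we know; newer releases may also consult other hooks first. Here is a minimal sketch with a hypothetical toy estimator:

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted


class ColumnStats(BaseEstimator):
    """Hypothetical toy estimator that only records column means and stds."""

    def __init__(self, ddof=0):
        self.ddof = ddof                       # hyperparameter: no underscore

    def fit(self, X, y=None):
        X = np.asarray(X, dtype=float)
        self.mean_ = X.mean(axis=0)            # learned from data: trailing underscore
        self.std_ = X.std(axis=0, ddof=self.ddof)
        return self


est = ColumnStats()
try:
    check_is_fitted(est)                       # looks for attributes ending in "_"
    fitted_before = True
except NotFittedError:
    fitted_before = False

est.fit([[1.0, 2.0], [3.0, 4.0]])
learned = sorted(name for name in vars(est) if name.endswith('_'))
print(fitted_before, learned)                  # False ['mean_', 'std_']
check_is_fitted(est)                           # passes silently now
```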
Finally, we have the logic for predicting labels on new data:
    def predict_proba(self, X):
        logprobs = np.vstack([model.score_samples(X)
                              for model in self.models_]).T
        result = np.exp(logprobs + self.logpriors_)
        return result / result.sum(1, keepdims=True)
    def predict(self, X):
        return self.classes_[np.argmax(self.predict_proba(X), 1)]
Because this is a probabilistic classifier, we first implement predict_proba(), which returns an array of class
probabilities of shape [n_samples, n_classes]. Entry [i, j] of this array is the posterior probability that
sample i is a member of class j, computed by multiplying the likelihood by the class prior and normalizing.
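One practical caveat concerns this normalization. The posterior is P(j | x_i) = P(x_i | j) P(j) / sum_k P(x_i | k) P(k). For high-dimensional data the per-class log-likelihoods can be so negative that np.exp underflows to zero for every class, and the division then produces nan. A common remedy is to do the normalization in log space with scipy.special.logsumexp. We suggest this as an alternative; it is not the original code. The helper name posterior_from_logs and the numbers below are hypothetical.

```python
import numpy as np
from scipy.special import logsumexp


def posterior_from_logs(logprobs, logpriors):
    """Turn per-class log-likelihoods and log-priors into posterior probabilities.

    logprobs  : array of shape (n_samples, n_classes), e.g. the stacked
                score_samples() output used in predict_proba()
    logpriors : array of shape (n_classes,)
    """
    joint = logprobs + logpriors                  # log[p(x | class) * p(class)]
    # subtracting the row-wise log of the normalizer keeps exp() in range
    return np.exp(joint - logsumexp(joint, axis=1, keepdims=True))


# two samples, two classes, with very negative log-likelihoods
logprobs = np.array([[-800.0, -802.0],
                     [-1000.0, -990.0]])
logpriors = np.log([0.5, 0.5])

naive = np.exp(logprobs + logpriors)
print(naive.sum(1))                               # [0. 0.]: dividing by this gives nan
print(posterior_from_logs(logprobs, logpriors).round(3))
```

Only differences between the per-class log values matter. Subtracting the row-wise logsumexp gives the same posterior as the direct formula without ever forming the tiny raw probabilities.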
Finally, the predict() method uses these probabilities and simply returns the class with the largest probability.
Using our custom estimator
Let's try this custom estimator on a problem we have seen before: the classification of hand-written digits. Here we will
load the digits, and compute the cross-validation score for a range of candidate bandwidths using the GridSearchCV
meta-estimator (refer back to Hyperparameters and Model Validation):
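Translator's note: The code below has been modified for newer versions of Scikit-Learn. GridSearchCV is now imported from sklearn.model_selection, the cv parameter is passed explicitly, and the scores are read from the cv_results_ dictionary.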
In [17]: from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV
digits = load_digits()
bandwidths = 10 ** np.linspace(0, 2, 100)
grid = GridSearchCV(KDEClassifier(), {'bandwidth': bandwidths}, cv=5)
grid.fit(digits.data, digits.target)
scores = grid.cv_results_['mean_test_score']
Next we can plot the cross-validation score as a function of bandwidth:
In [18]: plt.semilogx(bandwidths, scores)
plt.xlabel('bandwidth')
plt.ylabel('accuracy')
plt.title('KDE Model Performance')
print(grid.best_params_)
print('accuracy =', grid.best_score_)
{'bandwidth': 6.135907273413174}
accuracy = 0.9677298050139276
We see that this not-so-naive Bayesian classifier reaches a cross-validation accuracy of just over 96%, compared with
around 80% for the naive Bayesian classification:
In [19]: from sklearn.naive_bayes import GaussianNB
from sklearn.model_selection import cross_val_score
cross_val_score(GaussianNB(), digits.data, digits.target, cv=5).mean()
Out[19]: 0.8069281956050759
One benefit of such a generative classifier is interpretability of results: for each unknown sample, we not only get a
probabilistic classification, but also a full model of the distribution of points we are comparing it to! If desired, this offers an
intuitive window into the reasons for a particular classification that algorithms like SVMs and random forests tend to
obscure.
If you would like to take this further, there are some improvements that could be made to our KDE classifier model:
we could allow the bandwidth in each class to vary independently
we could optimize these bandwidths not based on their prediction score, but on the likelihood of the training data
under the generative model within each class (i.e. use the scores from KernelDensity itself rather than the
global prediction accuracy)
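Here is a sketch of the second suggestion, combined with the first since each class gets its own value. Each class's bandwidth is chosen by cross-validating the KernelDensity log-likelihood on that class's training points alone. When GridSearchCV is given no scoring argument, it falls back to the estimator's own score() method, which for KernelDensity is the total log-likelihood of the held-out data. The helper name per_class_bandwidths and the toy data are hypothetical.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity


def per_class_bandwidths(X, y, bandwidths, cv=3):
    """Hypothetical helper: choose one KDE bandwidth per class by
    cross-validated log-likelihood of that class's own training points."""
    best = {}
    for c in np.unique(y):
        search = GridSearchCV(KernelDensity(), {'bandwidth': bandwidths}, cv=cv)
        search.fit(X[y == c])      # no scoring given, so KernelDensity.score is used
        best[c] = search.best_params_['bandwidth']
    return best


rng = np.random.RandomState(42)
# toy data: one tight class and one spread-out class
X = np.vstack([rng.normal(0, 0.5, size=(60, 1)),
               rng.normal(5, 2.0, size=(60, 1))])
y = np.array([0] * 60 + [1] * 60)
grid = np.logspace(-1, 1, 20)
print(per_class_bandwidths(X, y, grid))
```

The resulting dictionary could then be used by a modified classifier that builds KernelDensity(bandwidth=best[c]) for each class c. How such a classifier would expose per-class parameters to GridSearchCV is a design decision left open here.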
Finally, if you want some practice building your own estimator, you might tackle building a similar Bayesian classifier
using Gaussian Mixture Models instead of KDE.
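For readers who want to check their attempt at that exercise, here is one possible sketch. It mirrors the structure of KDEClassifier but fits a sklearn.mixture.GaussianMixture to each class and normalizes in log space with scipy.special.logsumexp. The class name GMMClassifier, its parameters, and the make_blobs test data are our own choices, not from the original text.

```python
import numpy as np
from scipy.special import logsumexp
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture


class GMMClassifier(BaseEstimator, ClassifierMixin):
    """Bayesian generative classification with one Gaussian mixture per class.

    Parameters
    ----------
    n_components : int
        number of mixture components fit within each class
    covariance_type : str
        passed to GaussianMixture
    """

    def __init__(self, n_components=1, covariance_type='full'):
        self.n_components = n_components
        self.covariance_type = covariance_type

    def fit(self, X, y):
        self.classes_ = np.sort(np.unique(y))
        self.models_ = [GaussianMixture(n_components=self.n_components,
                                        covariance_type=self.covariance_type,
                                        random_state=0).fit(X[y == c])
                        for c in self.classes_]
        self.logpriors_ = np.log([np.mean(y == c) for c in self.classes_])
        return self

    def predict_proba(self, X):
        # log p(x | class) + log p(class), one column per class
        joint = np.vstack([m.score_samples(X) for m in self.models_]).T + self.logpriors_
        return np.exp(joint - logsumexp(joint, axis=1, keepdims=True))

    def predict(self, X):
        return self.classes_[np.argmax(self.predict_proba(X), 1)]


X, y = make_blobs(n_samples=300, centers=3, cluster_std=0.5, random_state=0)
clf = GMMClassifier(n_components=2).fit(X, y)
print(clf.score(X, y))
```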
Application: A Face Detection Pipeline
This chapter has explored a number of the central concepts and algorithms of machine learning. But moving from these
concepts to real-world application can be a challenge. Real-world datasets are noisy and heterogeneous, may have
missing features, and data may be in a form that is difficult to map to a clean [n_samples, n_features] matrix.
Before applying any of the methods discussed here, you must first extract these features from your data: there is no
formula for how to do this that applies across all domains, and thus this is where you as a data scientist must exercise
your own intuition and expertise.
One interesting and compelling application of machine learning is to images, and we have already seen a few examples
of this where pixel-level features are used for classification. In the real world, data is rarely so uniform, and simple pixel values
will not be suitable: this has led to a large literature on feature extraction methods for image data (see Feature
Engineering).
In this section, we will take a look at one such feature extraction technique, the Histogram of Oriented Gradients (HOG),
which transforms image pixels into a vector representation that is sensitive to broadly informative image features
regardless of confounding factors like illumination. We will use these features to develop a simple face detection pipeline,
using machine learning algorithms and concepts we've seen throughout this chapter.
We begin with the standard imports:
In [1]: %matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import numpy as np
HOG Features
The Histogram of Oriented Gradients is a straightforward feature extraction procedure that was developed in the context of
identifying pedestrians within images. HOG involves the following steps:
1. Optionally pre-normalize images. This leads to features that resist dependence on variations in illumination.
2. Convolve the image with two filters that are sensitive to horizontal and vertical brightness gradients. These capture edge,
contour, and texture information.
3. Subdivide the image into cells of a predetermined size, and compute a histogram of the gradient orientations within
each cell.
4. Normalize the histograms in each cell by comparing to the block of neighboring cells. This further suppresses the
effect of illumination across the image.
5. Construct a one-dimensional feature vector from the information in each cell.
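Steps 2 and 3 can be illustrated in a few lines of NumPy. This is a deliberately simplified sketch, not how scikit-image implements HOG. It uses np.gradient instead of explicit filters and assigns each pixel's gradient magnitude to a single unsigned-orientation bin, with no interpolation between bins. It also skips the block normalization of step 4. The function name and the defaults (8×8 cells, 9 bins) are assumptions for illustration.

```python
import numpy as np


def cell_orientation_histograms(img, cell=8, nbins=9):
    """Simplified sketch of HOG steps 2-3: image gradients, then a
    magnitude-weighted histogram of gradient orientation in each cell."""
    img = np.asarray(img, dtype=float)
    gy, gx = np.gradient(img)                       # vertical, horizontal gradients
    mag = np.hypot(gx, gy)
    ang = np.rad2deg(np.arctan2(gy, gx)) % 180      # unsigned orientation in [0, 180)
    bins = np.minimum((ang // (180 / nbins)).astype(int), nbins - 1)
    n_i, n_j = img.shape[0] // cell, img.shape[1] // cell
    hist = np.zeros((n_i, n_j, nbins))
    for i in range(n_i):
        for j in range(n_j):
            sl = (slice(i * cell, (i + 1) * cell), slice(j * cell, (j + 1) * cell))
            hist[i, j] = np.bincount(bins[sl].ravel(), weights=mag[sl].ravel(),
                                     minlength=nbins)
    return hist


img = np.zeros((16, 16))
img[:, 8:] = 1.0                      # a single vertical edge
hist = cell_orientation_histograms(img)
print(hist.shape)                     # (2, 2, 9)
print(hist[..., 0])                   # all gradient energy sits in the 0-degree bin
```

For an image whose right half is bright, all of the gradient energy lands in the 0° bin of each cell. This is because brightness changes only along the horizontal direction.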
A fast HOG extractor is built into the Scikit-Image project, and we can try it out relatively quickly and visualize the oriented
gradients within each cell:
In [2]: from skimage import data, color, feature
import skimage.data
image = color.rgb2gray(data.chelsea())
hog_vec, hog_vis = feature.hog(image, visualize=True)
fig, ax = plt.subplots(1, 2, figsize=(12, 6),
subplot_kw=dict(xticks=[], yticks=[]))
ax[0].imshow(image, cmap='gray')
ax[0].set_title('input image')
ax[1].imshow(hog_vis)
ax[1].set_title('visualization of HOG features');
HOG in Action: A Simple Face Detector
Using these HOG features, we can build up a simple facial detection algorithm with any Scikit-Learn estimator; here we
will use a linear support vector machine (refer back to In-Depth: Support Vector Machines if you need a refresher on this).
The steps are as follows:
1. Obtain a set of image thumbnails of faces to constitute "positive" training samples.
2. Obtain a set of image thumbnails of non-faces to constitute "negative" training samples.
3. Extract HOG features from these training samples.
4. Train a linear SVM classifier on these samples.
5. For an "unknown" image, pass a sliding window across the image, using the model to evaluate whether that window
contains a face or not.
6. If detections overlap, combine them into a single window.
Let's go through these steps and try it out:
1. Obtain a set of positive training samples
Let's start by finding some positive training samples that show a variety of faces. We have one easy set of data to work
with, the Labeled Faces in the Wild dataset, which can be downloaded with Scikit-Learn:
In [3]: from sklearn.datasets import fetch_lfw_people
faces = fetch_lfw_people()
positive_patches = faces.images
positive_patches.shape
Out[3]: (13233, 62, 47)
This gives us 13,233 face images to use for training.
2. Obtain a set of negative training samples
Next we need a set of similarly sized thumbnails which do not have a face in them. One way to do this is to take any
corpus of input images, and extract thumbnails from them at a variety of scales. Here we can use some of the images
shipped with Scikit-Image, along with Scikit-Learn's PatchExtractor:
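In [4]: from skimage import data, transform
imgs_to_use = ['camera', 'text', 'coins', 'moon',
               'page', 'clock', 'immunohistochemistry',
               'chelsea', 'coffee', 'hubble_deep_field']
# apply rgb2gray only to the colour images: newer scikit-image releases reject
# 2D (already grayscale) input, which older releases returned unchanged
images = [color.rgb2gray(img) if img.ndim == 3 else img
          for img in (getattr(data, name)() for name in imgs_to_use)]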
In [5]: from sklearn.feature_extraction.image import PatchExtractor
def extract_patches(img, N, scale=1.0, patch_size=positive_patches[0].shape):
    extracted_patch_size = tuple((scale * np.array(patch_size)).astype(int))
    extractor = PatchExtractor(patch_size=extracted_patch_size,
                               max_patches=N, random_state=0)
    patches = extractor.transform(img[np.newaxis])
    if scale != 1:
        patches = np.array([transform.resize(patch, patch_size)
                            for patch in patches])
    return patches
negative_patches = np.vstack([extract_patches(im, 1000, scale)
                              for im in images for scale in [0.5, 1.0, 2.0]])
negative_patches.shape
Out[5]: (30000, 62, 47)
We now have 30,000 suitable image patches which do not contain faces. Let's take a look at a few of them to get an idea
of what they look like:
In [6]: fig, ax = plt.subplots(6, 10)
for i, axi in enumerate(ax.flat):
    axi.imshow(negative_patches[500 * i], cmap='gray')
    axi.axis('off')
Our hope is that these would sufficiently cover the space of "non-faces" that our algorithm is likely to see.
3. Combine sets and extract HOG features
Now that we have these positive samples and negative samples, we can combine them and compute HOG features. This
step takes a little while, because the HOG features involve a nontrivial computation for each image:
In [7]: from itertools import chain
X_train = np.array([feature.hog(im)
for im in chain(positive_patches,
negative_patches)])
y_train = np.zeros(X_train.shape[0])
y_train[:positive_patches.shape[0]] = 1
In [8]: X_train.shape
Out[8]: (43233, 1215)
We are left with 43,000 training samples in 1,215 dimensions, and we now have our data in a form that we can feed into
Scikit-Learn!
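The 1,215 dimensions follow from simple arithmetic on the thumbnail size. This assumes scikit-image's default hog() settings (9 orientations, 8×8-pixel cells, 3×3-cell blocks), which are consistent with the number printed above. A 62×47 patch holds 7×5 whole cells. A 3×3 block can be placed at (7 − 3 + 1) × (5 − 3 + 1) = 15 positions. Each block contributes 3 × 3 × 9 = 81 values, giving 15 × 81 = 1,215. The small helper below is hypothetical and only checks this arithmetic:

```python
def hog_length(shape, orientations=9, pixels_per_cell=(8, 8), cells_per_block=(3, 3)):
    """Length of a HOG feature vector for an image of the given (rows, cols),
    counting whole cells only and blocks that slide one cell at a time.
    The numbers in the comments are for a 62x47 patch with the defaults."""
    cells_r = shape[0] // pixels_per_cell[0]          # 62 // 8 = 7
    cells_c = shape[1] // pixels_per_cell[1]          # 47 // 8 = 5
    blocks_r = cells_r - cells_per_block[0] + 1       # 7 - 3 + 1 = 5
    blocks_c = cells_c - cells_per_block[1] + 1       # 5 - 3 + 1 = 3
    per_block = cells_per_block[0] * cells_per_block[1] * orientations  # 81
    return blocks_r * blocks_c * per_block


print(hog_length((62, 47)))   # 1215
```

Comparing hog_length(shape) with feature.hog(np.zeros(shape)).shape is a quick way to confirm which defaults your installed scikit-image uses.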
4. Training a support vector machine
Next we use the tools we have been exploring in this chapter to create a classifier of thumbnail patches. For such a high-dimensional binary classification task, a linear support vector machine is a good choice. We will use Scikit-Learn's
LinearSVC, because compared to SVC it often scales better to large numbers of samples.
First, though, let's use a simple Gaussian naive Bayes to get a quick baseline:
In [9]: from sklearn.naive_bayes import GaussianNB
from sklearn.model_selection import cross_val_score
cross_val_score(GaussianNB(), X_train, y_train, cv=3)
Out[9]: array([0.95385469, 0.97349247, 0.97501908])
We see that on our training data, even a simple naive Bayes algorithm gets us upwards of 90% accuracy. Let's try the
support vector machine, with a grid search over a few choices of the C parameter:
Translator's note: The code below sets cv=3 to match newer Scikit-Learn versions and adds dual=False to suppress a warning.
In [10]: from sklearn.svm import LinearSVC
from sklearn.model_selection import GridSearchCV
grid = GridSearchCV(LinearSVC(dual=False), {'C': [1.0, 2.0, 4.0, 8.0]}, cv=3)
grid.fit(X_train, y_train)
grid.best_score_
Out[10]: 0.9891286748548561
In [11]: grid.best_params_
Out[11]: {'C': 1.0}
Let's take the best estimator and re-train it on the full dataset:
In [12]: model = grid.best_estimator_
model.fit(X_train, y_train)
Out[12]: LinearSVC(C=1.0, class_weight=None, dual=False, fit_intercept=True,
intercept_scaling=1, loss='squared_hinge', max_iter=1000,
multi_class='ovr', penalty='l2', random_state=None, tol=0.0001,
verbose=0)
5. Find faces in a new image
Now that we have this model in place, let's grab a new image and see how the model does. We will use one portion of
the astronaut image for simplicity (see discussion of this in Caveats and Improvements), and run a sliding window over it
and evaluate each patch:
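In [13]: test_image = skimage.data.astronaut()
test_image = skimage.color.rgb2gray(test_image)
# the image is already 2D grayscale, so no channel argument is needed; the
# multichannel=False keyword used originally is deprecated in newer scikit-image
test_image = skimage.transform.rescale(test_image, 0.5)
test_image = test_image[:160, 40:180]
plt.imshow(test_image, cmap='gray')
plt.axis('off');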
Next, let's create a window that iterates over patches of this image, and compute HOG features for each patch:
In [14]: def sliding_window(img, patch_size=positive_patches[0].shape,
                   istep=2, jstep=2, scale=1.0):
    Ni, Nj = (int(scale * s) for s in patch_size)
    for i in range(0, img.shape[0] - Ni, istep):
        # note: the column range uses Ni (the patch height), as in the original;
        # img.shape[1] - Nj would also reach the right-most columns, but the
        # output shown below was produced with this version
        for j in range(0, img.shape[1] - Ni, jstep):
            patch = img[i:i + Ni, j:j + Nj]
            if scale != 1:
                patch = transform.resize(patch, patch_size)
            yield (i, j), patch
indices, patches = zip(*sliding_window(test_image))
patches_hog = np.array([feature.hog(patch) for patch in patches])
patches_hog.shape
Out[14]: (1911, 1215)
Finally, we can take these HOG-featured patches and use our model to evaluate whether each patch contains a face:
In [15]: labels = model.predict(patches_hog)
labels.sum()
Out[15]: 49.0
We see that out of nearly 2,000 patches, we have found 49 detections. Let's use the information we have about these
patches to show where they lie on our test image, drawing them as rectangles:
In [16]: fig, ax = plt.subplots()
ax.imshow(test_image, cmap='gray')
ax.axis('off')
Ni, Nj = positive_patches[0].shape
indices = np.array(indices)
for i, j in indices[labels == 1]:
    ax.add_patch(plt.Rectangle((j, i), Nj, Ni, edgecolor='red',
                               alpha=0.3, lw=2, facecolor='none'))
All of the detected patches overlap, and together they have found the face in the image! Not bad for a few lines of Python.
Caveats and Improvements
If you dig a bit deeper into the preceding code and examples, you'll see that we still have a bit of work before we can
claim a production-ready face detector. There are several issues with what we've done, and several improvements that
could be made. In particular:
Our training set, especially for negative features, is not very complete
The central issue is that there are many face-like textures that are not in the training set, and so our current model is very
prone to false positives. You can see this if you try out the above algorithm on the full astronaut image: the current model
leads to many false detections in other regions of the image.
We might imagine addressing this by adding a wider variety of images to the negative training set, and this would
probably yield some improvement. Another way to address this is to use a more directed approach, such as hard
negative mining. In hard negative mining, we take a new set of images that our classifier has not seen, find all the
patches representing false positives, and explicitly add them as negative instances in the training set before re-training
the classifier.
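Hard negative mining can be sketched as a loop around any Scikit-Learn classifier. The helper below is hypothetical, and the synthetic five-dimensional features merely stand in for HOG vectors. The loop fits the model and collects the known-negative patches that it wrongly labels as faces. It then appends them to the training set with label 0 and refits, for a fixed number of rounds.

```python
import numpy as np
from sklearn.svm import LinearSVC


def hard_negative_mining(model, X_train, y_train, X_negative_pool, n_rounds=3):
    """Hypothetical helper: repeatedly add false positives from a pool of
    known non-face feature vectors to the training set, then refit."""
    model.fit(X_train, y_train)
    for _ in range(n_rounds):
        pred = model.predict(X_negative_pool)
        hard = X_negative_pool[pred == 1]              # false positives
        if len(hard) == 0:
            break
        X_train = np.vstack([X_train, hard])
        y_train = np.concatenate([y_train, np.zeros(len(hard))])
        X_negative_pool = X_negative_pool[pred != 1]   # never add a patch twice
        model.fit(X_train, y_train)
    return model, X_train, y_train


rng = np.random.RandomState(0)
# toy stand-in for HOG features: "faces" near +1, easy negatives near -1,
# and a pool of harder negatives that sit closer to the faces
X_pos = rng.normal(1.0, 0.5, size=(100, 5))
X_neg = rng.normal(-1.0, 0.5, size=(100, 5))
pool = rng.normal(0.3, 0.5, size=(200, 5))
X0 = np.vstack([X_pos, X_neg])
y0 = np.concatenate([np.ones(100), np.zeros(100)])

model, X1, y1 = hard_negative_mining(LinearSVC(dual=False), X0, y0, pool)
print(X0.shape[0], '->', X1.shape[0], 'training samples')
```

In the face detector, the pool would be HOG features of patches cut from new images known to contain no faces.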
Our current pipeline searches only at one scale
As currently written, our algorithm will miss faces that are not approximately 62×47 pixels. This can be straightforwardly
addressed by using sliding windows of a variety of sizes, and re-sizing each patch using
skimage.transform.resize before feeding it into the model. In fact, the sliding_window() utility used here is
already built with this in mind.
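A multi-scale search can be sketched as a generator that varies the window size and resamples every window back to the 62×47 training shape. To keep the example dependency-free, a crude nearest-neighbour resampler stands in for skimage.transform.resize. The scales and step size are arbitrary assumptions. Unlike sliding_window() above, the ranges here include the last valid window position.

```python
import numpy as np


def nearest_resize(patch, shape):
    """Nearest-neighbour resampling, standing in for skimage.transform.resize."""
    rows = (np.arange(shape[0]) * patch.shape[0] / shape[0]).astype(int)
    cols = (np.arange(shape[1]) * patch.shape[1] / shape[1]).astype(int)
    return patch[np.ix_(rows, cols)]


def multiscale_windows(img, patch_size=(62, 47), scales=(0.5, 1.0, 1.5), step=4):
    """Yield (scale, (i, j), patch) for windows of several sizes, each patch
    resampled to `patch_size` so one fixed-size model can score it."""
    for scale in scales:
        Ni, Nj = int(scale * patch_size[0]), int(scale * patch_size[1])
        for i in range(0, img.shape[0] - Ni + 1, step):
            for j in range(0, img.shape[1] - Nj + 1, step):
                patch = img[i:i + Ni, j:j + Nj]
                if (Ni, Nj) != tuple(patch_size):
                    patch = nearest_resize(patch, patch_size)
                yield scale, (i, j), patch


windows = list(multiscale_windows(np.zeros((100, 100))))
print(len(windows))  # 516
```

In the face-detection setting, each yielded patch would then go through feature.hog and model.predict. The scale would be used to map each detection back to its window size in the original image.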
We should combine overlapped detection patches
For a production-ready pipeline, we would prefer not to have dozens of detections of the same face, but to somehow reduce
overlapping groups of detections down to a single detection. This could be done via an unsupervised clustering approach
(MeanShift Clustering is one good candidate for this), or via a procedural approach such as non-maximum suppression,
an algorithm common in machine vision.
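Non-maximum suppression is short enough to sketch directly. The version below is a common greedy variant, written for this note rather than taken from a library. It keeps the highest-scoring box and discards remaining boxes whose intersection-over-union (IoU) with it exceeds a threshold, then repeats. For the detector above, boxes could be built as (i, j, i + Ni, j + Nj) from indices[labels == 1], with scores taken from model.decision_function on the matching HOG features. The 0.3 threshold is an arbitrary assumption.

```python
import numpy as np


def non_max_suppression(boxes, scores, iou_threshold=0.3):
    """Greedy NMS. boxes: (n, 4) array of (top, left, bottom, right) corners.
    Returns the indices of the boxes that are kept, best score first."""
    boxes = np.asarray(boxes, dtype=float)
    scores = np.asarray(scores, dtype=float)
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = np.argsort(scores)[::-1]          # highest score first
    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(int(best))
        # overlap rectangle between the best box and every remaining box
        top = np.maximum(boxes[best, 0], boxes[rest, 0])
        left = np.maximum(boxes[best, 1], boxes[rest, 1])
        bottom = np.minimum(boxes[best, 2], boxes[rest, 2])
        right = np.minimum(boxes[best, 3], boxes[rest, 3])
        inter = np.clip(bottom - top, 0, None) * np.clip(right - left, 0, None)
        iou = inter / (areas[best] + areas[rest] - inter)
        order = rest[iou <= iou_threshold]    # drop boxes that overlap too much
    return keep


boxes = [(0, 0, 62, 47), (2, 2, 64, 49), (100, 100, 162, 147)]
scores = [0.9, 0.8, 0.7]
print(non_max_suppression(boxes, scores))  # [0, 2]
```

The first two boxes overlap with an IoU of about 0.86, so only the higher-scoring one survives. The distant third box is kept.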
The pipeline should be streamlined
Once we address these issues, it would also be nice to create a more streamlined pipeline for ingesting training images
and predicting sliding-window outputs. This is where Python as a data science tool really shines: with a bit of work, we
could take our prototype code and package it with a well-designed object-oriented API that gives the user the ability to use
this easily. I will leave this as a proverbial "exercise for the reader".
More recent advances: Deep Learning
Finally, I should add that HOG and other procedural feature extraction methods for images are no longer state-of-the-art
techniques. Instead, many modern object detection pipelines use variants of deep neural networks: one way to think of
neural networks is that they are an estimator which determines optimal feature extraction strategies from the data, rather
than relying on the intuition of the user. An intro to these deep neural net methods is conceptually (and computationally!)
beyond the scope of this section, although open tools like Google's TensorFlow have recently made deep learning
approaches much more accessible than they once were. As of the writing of this book, deep learning in Python is still
relatively young, and so I can't yet point to any definitive resource. That said, the list of references in the following section
should provide a useful place to start!
Further Machine Learning Resources
This chapter has been a quick tour of machine learning in Python, primarily using the tools within the Scikit-Learn library.
As long as the chapter is, it is still too short to cover many interesting and important algorithms, approaches, and
discussions. Here I want to suggest some resources to learn more about machine learning for those who are interested.
Machine Learning in Python
To learn more about machine learning in Python, I'd suggest some of the following resources:
The Scikit-Learn website: The Scikit-Learn website has an impressive breadth of documentation and examples
covering some of the models discussed here, and much, much more. If you want a brief survey of the most important
and often-used machine learning algorithms, this website is a good place to start.
SciPy, PyCon, and PyData tutorial videos: Scikit-Learn and other machine learning topics are perennial favorites in
the tutorial tracks of many Python-focused conference series, in particular the PyCon, SciPy, and PyData
conferences. You can find the most recent ones via a simple web search.
Introduction to Machine Learning with Python: Written by Andreas C. Mueller and Sarah Guido, this book includes a
fuller treatment of the topics in this chapter. If you're interested in reviewing the fundamentals of Machine Learning
and pushing the Scikit-Learn toolkit to its limits, this is a great resource, written by one of the most prolific developers
on the Scikit-Learn team.
Python Machine Learning: Sebastian Raschka's book focuses less on Scikit-learn itself, and more on the breadth of
machine learning tools available in Python. In particular, there is some very useful discussion on how to scale
Python-based machine learning approaches to large and complex datasets.
General Machine Learning
Of course, machine learning is much broader than just the Python world. There are many good resources to take your
knowledge further, and here I will highlight a few that I have found useful:
Machine Learning: Taught by Andrew Ng (Coursera), this is a very clearly-taught free online course which covers the
basics of machine learning from an algorithmic perspective. It assumes undergraduate-level understanding of
mathematics and programming, and steps through detailed considerations of some of the most important machine
learning algorithms. Homework assignments, which are algorithmically graded, have you actually implement some of
these models yourself.
Pattern Recognition and Machine Learning: Written by Christopher Bishop, this classic technical text covers the
concepts of machine learning discussed in this chapter in detail. If you plan to go further in this subject, you should
have this book on your shelf.
Machine Learning: a Probabilistic Perspective: Written by Kevin Murphy, this is an excellent graduate-level text that
explores nearly all important machine learning algorithms from a ground-up, unified probabilistic perspective.
These resources are more technical than the material presented in this book, but really understanding the fundamentals
of these methods requires a deep dive into the mathematics behind them. If you're up for the challenge and ready to
bring your data science to the next level, don't hesitate to dive in!
Appendix: Figure Code
Many of the figures used throughout this text are created in-place by code that appears in print. In a few cases, however,
the required code is long enough (or not immediately relevant enough) that we instead put it here for reference.
In [1]: %matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
In [2]: import os
if not os.path.exists('figures'):
    os.makedirs('figures')
Broadcasting
In [3]: # Adapted from astroML: see http://www.astroml.org/book_figures/appendix/fig_broadcast_visual.html
import numpy as np
from matplotlib import pyplot as plt
#-----------------------------------------------------------#
# Draw a figure and axis with no boundary
fig = plt.figure(figsize=(6, 4.5), facecolor='w')
ax = plt.axes([0, 0, 1, 1], xticks=[], yticks=[], frameon=False)
def draw_cube(ax, xy, size, depth=0.4,
              edges=None, label=None, label_kwargs=None, **kwargs):
    """draw and label a cube. edges is a list of numbers between
    1 and 12, specifying which of the 12 cube edges to draw"""
    if edges is None:
        edges = range(1, 13)
    x, y = xy
    if 1 in edges:
        ax.plot([x, x + size],
                [y + size, y + size], **kwargs)
    if 2 in edges:
        ax.plot([x + size, x + size],
                [y, y + size], **kwargs)
    if 3 in edges:
        ax.plot([x, x + size],
                [y, y], **kwargs)
    if 4 in edges:
        ax.plot([x, x],
                [y, y + size], **kwargs)
    if 5 in edges:
        ax.plot([x, x + depth],
                [y + size, y + depth + size], **kwargs)
    if 6 in edges:
        ax.plot([x + size, x + size + depth],
                [y + size, y + depth + size], **kwargs)
    if 7 in edges:
        ax.plot([x + size, x + size + depth],
                [y, y + depth], **kwargs)
    if 8 in edges:
        ax.plot([x, x + depth],
                [y, y + depth], **kwargs)
    if 9 in edges:
        ax.plot([x + depth, x + depth + size],
                [y + depth + size, y + depth + size], **kwargs)
    if 10 in edges:
        ax.plot([x + depth + size, x + depth + size],
                [y + depth, y + depth + size], **kwargs)
    if 11 in edges:
        ax.plot([x + depth, x + depth + size],
                [y + depth, y + depth], **kwargs)
    if 12 in edges:
        ax.plot([x + depth, x + depth],
                [y + depth, y + depth + size], **kwargs)
    if label:
        if label_kwargs is None:
            label_kwargs = {}
        ax.text(x + 0.5 * size, y + 0.5 * size, label,
                ha='center', va='center', **label_kwargs)
solid = dict(c='black', ls='-', lw=1,
label_kwargs=dict(color='k'))
dotted = dict(c='black', ls='-', lw=0.5, alpha=0.5,
label_kwargs=dict(color='gray'))
depth = 0.3
#-----------------------------------------------------------#
# Draw top operation: vector plus scalar
draw_cube(ax, (1, 10), 1, depth, [1, 2, 3, 4, 5, 6, 9], '0', **solid)
draw_cube(ax, (2, 10), 1, depth, [1, 2, 3, 6, 9], '1', **solid)
draw_cube(ax, (3, 10), 1, depth, [1, 2, 3, 6, 7, 9, 10], '2', **solid)
draw_cube(ax, (6, 10), 1, depth, [1, 2, 3, 4, 5, 6, 7, 9, 10], '5', **solid)
draw_cube(ax, (7, 10), 1, depth, [1, 2, 3, 6, 7, 9, 10, 11], '5', **dotted)
draw_cube(ax, (8, 10), 1, depth, [1, 2, 3, 6, 7, 9, 10, 11], '5', **dotted)
draw_cube(ax, (12, 10), 1, depth, [1, 2, 3, 4, 5, 6, 9], '5', **solid)
draw_cube(ax, (13, 10), 1, depth, [1, 2, 3, 6, 9], '6', **solid)
draw_cube(ax, (14, 10), 1, depth, [1, 2, 3, 6, 7, 9, 10], '7', **solid)
ax.text(5, 10.5, '+', size=12, ha='center', va='center')
ax.text(10.5, 10.5, '=', size=12, ha='center', va='center')
ax.text(1, 11.5, r'${\tt np.arange(3) + 5}$',
size=12, ha='left', va='bottom')
#-----------------------------------------------------------#
# Draw middle operation: matrix plus vector
# first block
draw_cube(ax, (1, 7.5), 1, depth, [1, 2, 3, 4, 5, 6, 9], '1', **solid)
draw_cube(ax, (2, 7.5), 1, depth, [1, 2, 3, 6, 9], '1', **solid)
draw_cube(ax, (3, 7.5), 1, depth, [1, 2, 3, 6, 7, 9, 10], '1', **solid)
draw_cube(ax, (1, 6.5), 1, depth, [2, 3, 4], '1', **solid)
draw_cube(ax, (2, 6.5), 1, depth, [2, 3], '1', **solid)
draw_cube(ax, (3, 6.5), 1, depth, [2, 3, 7, 10], '1', **solid)
draw_cube(ax, (1, 5.5), 1, depth, [2, 3, 4], '1', **solid)
draw_cube(ax, (2, 5.5), 1, depth, [2, 3], '1', **solid)
draw_cube(ax, (3, 5.5), 1, depth, [2, 3, 7, 10], '1', **solid)
# second block
draw_cube(ax, (6, 7.5), 1, depth, [1, 2, 3, 4, 5, 6, 9], '0', **solid)
draw_cube(ax, (7, 7.5), 1, depth, [1, 2, 3, 6, 9], '1', **solid)
draw_cube(ax, (8, 7.5), 1, depth, [1, 2, 3, 6, 7, 9, 10], '2', **solid)
draw_cube(ax, (6, 6.5), 1, depth, range(2, 13), '0', **dotted)
draw_cube(ax, (7, 6.5), 1, depth, [2, 3, 6, 7, 9, 10, 11], '1', **dotted)
draw_cube(ax, (8, 6.5), 1, depth, [2, 3, 6, 7, 9, 10, 11], '2', **dotted)
draw_cube(ax, (6, 5.5), 1, depth, [2, 3, 4, 7, 8, 10, 11, 12], '0', **dotted)
draw_cube(ax, (7, 5.5), 1, depth, [2, 3, 7, 10, 11], '1', **dotted)
draw_cube(ax, (8, 5.5), 1, depth, [2, 3, 7, 10, 11], '2', **dotted)
# third block
draw_cube(ax, (12, 7.5), 1, depth, [1, 2, 3, 4, 5, 6, 9], '1', **solid)
draw_cube(ax, (13, 7.5), 1, depth, [1, 2, 3, 6, 9], '2', **solid)
draw_cube(ax, (14, 7.5), 1, depth, [1, 2, 3, 6, 7, 9, 10], '3', **solid)
draw_cube(ax, (12, 6.5), 1, depth, [2, 3, 4], '1', **solid)
draw_cube(ax, (13, 6.5), 1, depth, [2, 3], '2', **solid)
draw_cube(ax, (14, 6.5), 1, depth, [2, 3, 7, 10], '3', **solid)
draw_cube(ax, (12, 5.5), 1, depth, [2, 3, 4], '1', **solid)
draw_cube(ax, (13, 5.5), 1, depth, [2, 3], '2', **solid)
draw_cube(ax, (14, 5.5), 1, depth, [2, 3, 7, 10], '3', **solid)
ax.text(5, 7.0, '+', size=12, ha='center', va='center')
ax.text(10.5, 7.0, '=', size=12, ha='center', va='center')
ax.text(1, 9.0, r'${\tt np.ones((3,\, 3)) + np.arange(3)}$',
size=12, ha='left', va='bottom')
#-----------------------------------------------------------#
# Draw the bottom operation: vector + vector, double broadcasting
# Part 1
draw_cube(ax, (1, 3), 1, depth, [1, 2, 3, 4, 5, 6, 7, 9, 10], '0', **solid)
draw_cube(ax, (1, 2), 1, depth, [2, 3, 4, 7, 10], '1', **solid)
draw_cube(ax, (1, 1), 1, depth, [2, 3, 4, 7, 10], '2', **solid)
draw_cube(ax, (2, 3), 1, depth, [1, 2, 3, 6, 7, 9, 10, 11], '0', **dotted)
draw_cube(ax, (2, 2), 1, depth, [2, 3, 7, 10, 11], '1', **dotted)
draw_cube(ax, (2, 1), 1, depth, [2, 3, 7, 10, 11], '2', **dotted)
draw_cube(ax, (3, 3), 1, depth, [1, 2, 3, 6, 7, 9, 10, 11], '0', **dotted)
draw_cube(ax, (3, 2), 1, depth, [2, 3, 7, 10, 11], '1', **dotted)
draw_cube(ax, (3, 1), 1, depth, [2, 3, 7, 10, 11], '2', **dotted)
# Part 2
draw_cube(ax, (6, 3), 1, depth, [1, 2, 3, 4, 5, 6, 9], '0', **solid)
draw_cube(ax, (7, 3), 1, depth, [1, 2, 3, 6, 9], '1', **solid)
draw_cube(ax, (8, 3), 1, depth, [1, 2, 3, 6, 7, 9, 10], '2', **solid)
draw_cube(ax, (6, 2), 1, depth, range(2, 13), '0', **dotted)
draw_cube(ax, (7, 2), 1, depth, [2, 3, 6, 7, 9, 10, 11], '1', **dotted)
draw_cube(ax, (8, 2), 1, depth, [2, 3, 6, 7, 9, 10, 11], '2', **dotted)
draw_cube(ax, (6, 1), 1, depth, [2, 3, 4, 7, 8, 10, 11, 12], '0', **dotted)
draw_cube(ax, (7, 1), 1, depth, [2, 3, 7, 10, 11], '1', **dotted)
draw_cube(ax, (8, 1), 1, depth, [2, 3, 7, 10, 11], '2', **dotted)
# Part 3
draw_cube(ax, (12, 3), 1, depth, [1, 2, 3, 4, 5, 6, 9], '0', **solid)
draw_cube(ax, (13, 3), 1, depth, [1, 2, 3, 6, 9], '1', **solid)
draw_cube(ax, (14, 3), 1, depth, [1, 2, 3, 6, 7, 9, 10], '2', **solid)
draw_cube(ax, (12, 2), 1, depth, [2, 3, 4], '1', **solid)
draw_cube(ax, (13, 2), 1, depth, [2, 3], '2', **solid)
draw_cube(ax, (14, 2), 1, depth, [2, 3, 7, 10], '3', **solid)
draw_cube(ax, (12, 1), 1, depth, [2, 3, 4], '2', **solid)
draw_cube(ax, (13, 1), 1, depth, [2, 3], '3', **solid)
draw_cube(ax, (14, 1), 1, depth, [2, 3, 7, 10], '4', **solid)
ax.text(5, 2.5, '+', size=12, ha='center', va='center')
ax.text(10.5, 2.5, '=', size=12, ha='center', va='center')
ax.text(1, 4.5, r'${\tt np.arange(3).reshape((3,\, 1)) + np.arange(3)}$',
ha='left', size=12, va='bottom')
ax.set_xlim(0, 16)
ax.set_ylim(0.5, 12.5)
fig.savefig('figures/02.05-broadcasting.png')
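The sketch below is a quick numerical check of the two lower panels. It evaluates the expressions annotated in the figure and prints the arrays whose values appear in the result cubes. It only needs NumPy, and the names `middle` and `bottom` are just illustrative.

```python
import numpy as np

# Middle panel: a (3, 3) matrix plus a length-3 vector.
# The vector is stretched along the first axis, so each row gets [0, 1, 2] added.
middle = np.ones((3, 3)) + np.arange(3)
print(middle)
# [[1. 2. 3.]
#  [1. 2. 3.]
#  [1. 2. 3.]]

# Bottom panel: a (3, 1) column plus a length-3 row.
# Both operands are stretched ("double broadcasting") into a (3, 3) result.
bottom = np.arange(3).reshape((3, 1)) + np.arange(3)
print(bottom)
# [[0 1 2]
#  [1 2 3]
#  [2 3 4]]
```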
Aggregation and Grouping
Figures from the chapter on aggregation and grouping
Split-Apply-Combine
In [4]: def draw_dataframe(df, loc=None, width=None, ax=None, linestyle=None,
                   textstyle=None):
    loc = loc or [0, 0]
    width = width or 1
    x, y = loc
    if ax is None:
        ax = plt.gca()
    ncols = len(df.columns) + 1
    nrows = len(df.index) + 1
    dx = dy = width / ncols
    if linestyle is None:
        linestyle = {'color':'black'}
    if textstyle is None:
        textstyle = {'size': 12}
    textstyle.update({'ha':'center', 'va':'center'})
    # Draw vertical lines
    for i in range(ncols + 1):
        plt.plot(2 * [x + i * dx], [y, y + dy * nrows], **linestyle)
    # Draw horizontal lines
    for i in range(nrows + 1):
        plt.plot([x, x + dx * ncols], 2 * [y + i * dy], **linestyle)
    # Create index labels
    for i in range(nrows - 1):
        plt.text(x + 0.5 * dx, y + (i + 0.5) * dy,
                 str(df.index[::-1][i]), **textstyle)
    # Create column labels
    for i in range(ncols - 1):
        plt.text(x + (i + 1.5) * dx, y + (nrows - 0.5) * dy,
                 str(df.columns[i]), style='italic', **textstyle)
    # Add the index name
    if df.index.name:
        plt.text(x + 0.5 * dx, y + (nrows - 0.5) * dy,
                 str(df.index.name), style='italic', **textstyle)
    # Insert the data
    for i in range(nrows - 1):
        for j in range(ncols - 1):
            plt.text(x + (j + 1.5) * dx,
                     y + (i + 0.5) * dy,
                     str(df.values[::-1][i, j]), **textstyle)
#---------------------------------------------------------#
# Draw the figure
import pandas as pd
df = pd.DataFrame({'data': [1, 2, 3, 4, 5, 6]},
index=['A', 'B', 'C', 'A', 'B', 'C'])
df.index.name = 'key'
fig = plt.figure(figsize=(8, 6), facecolor='white')
ax = plt.axes([0, 0, 1, 1])
ax.axis('off')
draw_dataframe(df, [0, 0])
for y, ind in zip([3, 1, -1], 'ABC'):
split = df[df.index == ind]
draw_dataframe(split, [2, y])
sum = pd.DataFrame(split.sum()).T
sum.index = [ind]
sum.index.name = 'key'
sum.columns = ['data']
draw_dataframe(sum, [4, y + 0.25])
result = df.groupby(df.index).sum()
draw_dataframe(result, [6, 0.75])
style = dict(fontsize=14, ha='center', weight='bold')
plt.text(0.5, 3.6, "Input", **style)
plt.text(2.5, 4.6, "Split", **style)
plt.text(4.5, 4.35, "Apply (sum)", **style)
plt.text(6.5, 2.85, "Combine", **style)
arrowprops = dict(facecolor='black', width=1, headwidth=6)
plt.annotate('', (1.8, 3.6), (1.2, 2.8), arrowprops=arrowprops)
plt.annotate('', (1.8, 1.75), (1.2, 1.75), arrowprops=arrowprops)
plt.annotate('', (1.8, -0.1), (1.2, 0.7), arrowprops=arrowprops)
plt.annotate('', (3.8, 3.8), (3.2, 3.8), arrowprops=arrowprops)
plt.annotate('', (3.8, 1.75), (3.2, 1.75), arrowprops=arrowprops)
plt.annotate('', (3.8, -0.3), (3.2, -0.3), arrowprops=arrowprops)
plt.annotate('', (5.8, 2.8), (5.2, 3.6), arrowprops=arrowprops)
plt.annotate('', (5.8, 1.75), (5.2, 1.75), arrowprops=arrowprops)
plt.annotate('', (5.8, 0.7), (5.2, -0.1), arrowprops=arrowprops)
plt.axis('equal')
plt.ylim(-1.5, 5);
fig.savefig('figures/03.08-split-apply-combine.png')
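The figure's code builds the combined result with `df.groupby(df.index).sum()`. To make the three stages concrete, here is a minimal pure-Python sketch over the same six rows. It illustrates the idea only and is not how pandas implements `groupby`; the names `groups`, `applied` and `combined` are hypothetical.

```python
from collections import defaultdict

# The same toy table as the figure: keys A, B, C each appear twice.
keys = ['A', 'B', 'C', 'A', 'B', 'C']
data = [1, 2, 3, 4, 5, 6]

# Split: collect the values belonging to each key.
groups = defaultdict(list)
for key, value in zip(keys, data):
    groups[key].append(value)

# Apply: reduce each group independently (here with a sum).
applied = {key: sum(values) for key, values in groups.items()}

# Combine: gather the per-group results into one table, sorted by key.
combined = dict(sorted(applied.items()))
print(combined)  # {'A': 5, 'B': 7, 'C': 9}
```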
What Is Machine Learning?
In [5]: # Formatting helper used by the plots that follow
def format_plot(ax, title):
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    ax.set_xlabel('feature 1', color='gray')
    ax.set_ylabel('feature 2', color='gray')
    ax.set_title(title, color='gray')
Classification Example Figures
Figure context
The following code generates the figures from the Classification section.
In [6]: from sklearn.datasets import make_blobs
from sklearn.svm import SVC
# Create 50 separated points
X, y = make_blobs(n_samples=50, centers=2,
random_state=0, cluster_std=0.60)
# Fit the support vector classifier
clf = SVC(kernel='linear')
clf.fit(X, y)
# Create some new points to predict
X2, _ = make_blobs(n_samples=80, centers=2,
random_state=0, cluster_std=0.80)
X2 = X2[50:]
# Predict the labels of the new points
y2 = clf.predict(X2)
Classification Example Figure 1
In [7]: # Plot the data points
fig, ax = plt.subplots(figsize=(8, 6))
point_style = dict(cmap='Paired', s=50)
ax.scatter(X[:, 0], X[:, 1], c=y, **point_style)
# Format the plot
format_plot(ax, 'Input Data')
ax.axis([-1, 4, -2, 7])
fig.savefig('figures/05.01-classification-1.png')
Classification Example Figure 2
In [8]: # Get the contours describing the model
xx = np.linspace(-1, 4, 10)
yy = np.linspace(-2, 7, 10)
xy1, xy2 = np.meshgrid(xx, yy)
Z = np.array([clf.decision_function([t])
for t in zip(xy1.flat, xy2.flat)]).reshape(xy1.shape)
# Plot the points and the model
fig, ax = plt.subplots(figsize=(8, 6))
line_style = dict(levels = [-1.0, 0.0, 1.0],
linestyles = ['dashed', 'solid', 'dashed'],
colors = 'gray', linewidths=1)
ax.scatter(X[:, 0], X[:, 1], c=y, **point_style)
ax.contour(xy1, xy2, Z, **line_style)
# Format the plot
format_plot(ax, 'Model Learned from Input Data')
ax.axis([-1, 4, -2, 7])
fig.savefig('figures/05.01-classification-2.png')
Classification Example Figure 3
In [9]: # Plot the results
fig, ax = plt.subplots(1, 2, figsize=(16, 6))
fig.subplots_adjust(left=0.0625, right=0.95, wspace=0.1)
ax[0].scatter(X2[:, 0], X2[:, 1], c='gray', **point_style)
ax[0].axis([-1, 4, -2, 7])
ax[1].scatter(X2[:, 0], X2[:, 1], c=y2, **point_style)
ax[1].contour(xy1, xy2, Z, **line_style)
ax[1].axis([-1, 4, -2, 7])
format_plot(ax[0], 'Unknown Data')
format_plot(ax[1], 'Predicted Labels')
fig.savefig('figures/05.01-classification-3.png')
Regression Example Figures
Figure context
The following code generates the figures from the regression section.
In [10]: from sklearn.linear_model import LinearRegression
# Create some data for the regression
rng = np.random.RandomState(1)
X = rng.randn(200, 2)
y = np.dot(X, [-2, 1]) + 0.1 * rng.randn(X.shape[0])
# Fit the regression model
model = LinearRegression()
model.fit(X, y)
# Create some new points to predict
X2 = rng.randn(100, 2)
# Predict the labels
y2 = model.predict(X2)
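`LinearRegression` fits an ordinary least-squares model. As a rough cross-check, the sketch below reuses the data recipe from the cell above and solves the same least-squares problem directly with `np.linalg.lstsq`, appending a column of ones for the intercept. The slopes should come out close to the `[-2, 1]` used to generate `y`, and the intercept close to zero. The names `A`, `coef`, `slopes` and `intercept` are illustrative.

```python
import numpy as np

# Same synthetic data recipe as above: y = -2*x1 + 1*x2 + small noise.
rng = np.random.RandomState(1)
X = rng.randn(200, 2)
y = np.dot(X, [-2, 1]) + 0.1 * rng.randn(X.shape[0])

# Append a column of ones so the intercept is estimated with the slopes.
A = np.hstack([X, np.ones((X.shape[0], 1))])
coef, *_ = np.linalg.lstsq(A, y, rcond=None)
slopes, intercept = coef[:2], coef[2]
print(slopes, intercept)  # slopes near [-2, 1], intercept near 0
```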
Regression Example Figure 1
In [11]: # Plot the data points
fig, ax = plt.subplots()
points = ax.scatter(X[:, 0], X[:, 1], c=y, s=50,
cmap='viridis')
# Format the plot
format_plot(ax, 'Input Data')
ax.axis([-4, 4, -3, 3])
fig.savefig('figures/05.01-regression-1.png')
Regression Example Figure 2
In [12]: from mpl_toolkits.mplot3d.art3d import Line3DCollection
points = np.hstack([X, y[:, None]]).reshape(-1, 1, 3)
segments = np.hstack([points, points])
segments[:, 0, 2] = -8
# Plot the points in 3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X[:, 0], X[:, 1], y, c=y, s=35,
cmap='viridis')
ax.add_collection3d(Line3DCollection(segments, colors='gray', alpha=0.2))
ax.scatter(X[:, 0], X[:, 1], -8 + np.zeros(X.shape[0]), c=y, s=10,
cmap='viridis')
# Format the plot
ax.patch.set_facecolor('white')
ax.view_init(elev=20, azim=-70)
ax.set_zlim3d(-8, 8)
ax.xaxis.set_major_formatter(plt.NullFormatter())
ax.yaxis.set_major_formatter(plt.NullFormatter())
ax.zaxis.set_major_formatter(plt.NullFormatter())
ax.set(xlabel='feature 1', ylabel='feature 2', zlabel='label')
# Hide the axis lines and tick marks (is there a better way?)
# ax.w_xaxis etc. are deprecated aliases of ax.xaxis etc.
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.line.set_visible(False)
    for tick in axis.get_ticklines():
        tick.set_visible(False)
fig.savefig('figures/05.01-regression-2.png')
Regression Example Figure 3
In [13]: from matplotlib.collections import LineCollection
# Plot the data points
fig, ax = plt.subplots()
pts = ax.scatter(X[:, 0], X[:, 1], c=y, s=50,
cmap='viridis', zorder=2)
# Compute and plot the model's color mesh
xx, yy = np.meshgrid(np.linspace(-4, 4),
np.linspace(-3, 3))
Xfit = np.vstack([xx.ravel(), yy.ravel()]).T
yfit = model.predict(Xfit)
zz = yfit.reshape(xx.shape)
ax.pcolorfast([-4, 4], [-3, 3], zz, alpha=0.5,
cmap='viridis', norm=pts.norm, zorder=1)
# Format the plot
format_plot(ax, 'Input Data with Linear Fit')
ax.axis([-4, 4, -3, 3])
fig.savefig('figures/05.01-regression-3.png')
Regression Example Figure 4
In [14]: # Plot the model fit
fig, ax = plt.subplots(1, 2, figsize=(16, 6))
fig.subplots_adjust(left=0.0625, right=0.95, wspace=0.1)
ax[0].scatter(X2[:, 0], X2[:, 1], c='gray', s=50)
ax[0].axis([-4, 4, -3, 3])
ax[1].scatter(X2[:, 0], X2[:, 1], c=y2, s=50,
cmap='viridis', norm=pts.norm)
ax[1].axis([-4, 4, -3, 3])
# Format the plot
format_plot(ax[0], 'Unknown Data')
format_plot(ax[1], 'Predicted Labels')
fig.savefig('figures/05.01-regression-4.png')
Clustering Example Figures
Figure context
The following code generates the figures from the clustering section.
In [15]: from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
# Create 100 points in 4 clusters
X, y = make_blobs(n_samples=100, centers=4,
random_state=42, cluster_std=1.5)
# Fit the k-means model
model = KMeans(4, random_state=0)
y = model.fit_predict(X)
Clustering Example Figure 1
In [16]: # Plot the input data
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(X[:, 0], X[:, 1], s=50, color='gray')
# Format the plot
format_plot(ax, 'Input Data')
fig.savefig('figures/05.01-clustering-1.png')
Clustering Example Figure 2
In [17]: # Plot the data colored by the learned cluster labels
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(X[:, 0], X[:, 1], s=50, c=y, cmap='viridis')
# Format the plot
format_plot(ax, 'Learned Cluster Labels')
fig.savefig('figures/05.01-clustering-2.png')
Dimensionality Reduction Example Figures
Figure context
The following code generates the figures from the dimensionality reduction section.
Dimensionality Reduction Example Figure 1
In [18]: from sklearn.datasets import make_swiss_roll
# Make the data
X, y = make_swiss_roll(200, noise=0.5, random_state=42)
X = X[:, [0, 2]]
# Visualize the data
fig, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], color='gray', s=30)
# Format the plot
format_plot(ax, 'Input Data')
fig.savefig('figures/05.01-dimesionality-1.png')
Dimensionality Reduction Example Figure 2
In [19]: from sklearn.manifold import Isomap
model = Isomap(n_neighbors=8, n_components=1)
y_fit = model.fit_transform(X).ravel()
# Visualize the data
fig, ax = plt.subplots()
pts = ax.scatter(X[:, 0], X[:, 1], c=y_fit, cmap='viridis', s=30)
cb = fig.colorbar(pts, ax=ax)
# Format the plot
format_plot(ax, 'Learned Latent Parameter')
cb.set_ticks([])
cb.set_label('Latent Variable', color='gray')
fig.savefig('figures/05.01-dimesionality-2.png')
Introducing Scikit-Learn
Features and Labels Grid
The following is the code generating the diagram showing the features matrix and target array.
In [20]: fig = plt.figure(figsize=(6, 4))
ax = fig.add_axes([0, 0, 1, 1])
ax.axis('off')
ax.axis('equal')
# Draw the features matrix
ax.vlines(range(6), ymin=0, ymax=9, lw=1)
ax.hlines(range(10), xmin=0, xmax=5, lw=1)
font_prop = dict(size=12, family='monospace')
ax.text(-1, -1, "Feature Matrix ($X$)", size=14)
ax.text(0.1, -0.3, r'n_features $\longrightarrow$', **font_prop)
ax.text(-0.1, 0.1, r'$\longleftarrow$ n_samples', rotation=90,
va='top', ha='right', **font_prop)
# Draw the target vector
ax.vlines(range(8, 10), ymin=0, ymax=9, lw=1)
ax.hlines(range(10), xmin=8, xmax=9, lw=1)
ax.text(7, -1, "Target Vector ($y$)", size=14)
ax.text(7.9, 0.1, r'$\longleftarrow$ n_samples', rotation=90,
va='top', ha='right', **font_prop)
ax.set_ylim(10, -2)
fig.savefig('figures/05.02-samples-features.png')
Hyperparameters and Model Validation
Cross-Validation Figures
In [21]: def draw_rects(N, ax, textprop={}):
    for i in range(N):
        ax.add_patch(plt.Rectangle((0, i), 5, 0.7, fc='white'))
        ax.add_patch(plt.Rectangle((5. * i / N, i), 5. / N, 0.7, fc='lightgray'))
        ax.text(5. * (i + 0.5) / N, i + 0.35,
                "validation\nset", ha='center', va='center', **textprop)
        ax.text(0, i + 0.35, "trial {0}".format(N - i),
                ha='right', va='center', rotation=90, **textprop)
    ax.set_xlim(-1, 6)
    ax.set_ylim(-0.2, N + 0.2)
2-Fold Cross-Validation
In [22]: fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
ax.axis('off')
draw_rects(2, ax, textprop=dict(size=14))
fig.savefig('figures/05.03-2-fold-CV.png')
5-Fold Cross-Validation
In [23]: fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
ax.axis('off')
draw_rects(5, ax, textprop=dict(size=10))
fig.savefig('figures/05.03-5-fold-CV.png')
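In these cross-validation figures, each trial holds out one contiguous block, so every sample lands in the validation set exactly once. Below is a minimal NumPy sketch of that split. It assumes no shuffling, which is similar in spirit to scikit-learn's `KFold` with `shuffle=False`. The figure numbers its trials in the opposite order, and `kfold_indices` is a hypothetical helper, not a library function.

```python
import numpy as np

def kfold_indices(n_samples, n_folds):
    """Yield (train_idx, val_idx) pairs; each validation fold is one contiguous block."""
    indices = np.arange(n_samples)
    # np.array_split also copes with n_samples not divisible by n_folds.
    for val_idx in np.array_split(indices, n_folds):
        train_idx = np.setdiff1d(indices, val_idx)  # everything else is used for training
        yield train_idx, val_idx

folds = list(kfold_indices(10, 5))
for trial, (train_idx, val_idx) in enumerate(folds, start=1):
    print(trial, val_idx)
# 1 [0 1]
# 2 [2 3]
# 3 [4 5]
# 4 [6 7]
# 5 [8 9]
```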
Overfitting and Underfitting
In [24]: import numpy as np
def make_data(N=30, err=0.8, rseed=1):
    # Randomly sample the data
    rng = np.random.RandomState(rseed)
    X = rng.rand(N, 1) ** 2
    y = 10 - 1. / (X.ravel() + 0.1)
    if err > 0:
        y += err * rng.randn(N)
    return X, y
In [25]: from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
def PolynomialRegression(degree=2, **kwargs):
    return make_pipeline(PolynomialFeatures(degree),
                         LinearRegression(**kwargs))
Bias-Variance Tradeoff
In [26]: X, y = make_data()
xfit = np.linspace(-0.1, 1.0, 1000)[:, None]
model1 = PolynomialRegression(1).fit(X, y)
model20 = PolynomialRegression(20).fit(X, y)
fig, ax = plt.subplots(1, 2, figsize=(16, 6))
fig.subplots_adjust(left=0.0625, right=0.95, wspace=0.1)
ax[0].scatter(X.ravel(), y, s=40)
ax[0].plot(xfit.ravel(), model1.predict(xfit), color='gray')
ax[0].axis([-0.1, 1.0, -2, 14])
ax[0].set_title('High-bias model: Underfits the data', size=14)
ax[1].scatter(X.ravel(), y, s=40)
ax[1].plot(xfit.ravel(), model20.predict(xfit), color='gray')
ax[1].axis([-0.1, 1.0, -2, 14])
ax[1].set_title('High-variance model: Overfits the data', size=14)
fig.savefig('figures/05.03-bias-variance.png')
Bias-Variance Tradeoff Metrics
In [27]: fig, ax = plt.subplots(1, 2, figsize=(16, 6))
fig.subplots_adjust(left=0.0625, right=0.95, wspace=0.1)
X2, y2 = make_data(10, rseed=42)
ax[0].scatter(X.ravel(), y, s=40, c='blue')
ax[0].plot(xfit.ravel(), model1.predict(xfit), color='gray')
ax[0].axis([-0.1, 1.0, -2, 14])
ax[0].set_title('High-bias model: Underfits the data', size=14)
ax[0].scatter(X2.ravel(), y2, s=40, c='red')
ax[0].text(0.02, 0.98, "training score: $R^2$ = {0:.2f}".format(model1.score(X, y)),
ha='left', va='top', transform=ax[0].transAxes, size=14, color='blue')
ax[0].text(0.02, 0.91, "validation score: $R^2$ = {0:.2f}".format(model1.score(X2, y2)),
ha='left', va='top', transform=ax[0].transAxes, size=14, color='red')
ax[1].scatter(X.ravel(), y, s=40, c='blue')
ax[1].plot(xfit.ravel(), model20.predict(xfit), color='gray')
ax[1].axis([-0.1, 1.0, -2, 14])
ax[1].set_title('High-variance model: Overfits the data', size=14)
ax[1].scatter(X2.ravel(), y2, s=40, c='red')
ax[1].text(0.02, 0.98, "training score: $R^2$ = {0:.2g}".format(model20.score(X, y)),
ha='left', va='top', transform=ax[1].transAxes, size=14, color='blue')
ax[1].text(0.02, 0.91, "validation score: $R^2$ = {0:.2g}".format(model20.score(X2, y2)),
ha='left', va='top', transform=ax[1].transAxes, size=14, color='red')
fig.savefig('figures/05.03-bias-variance-2.png')
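The scores annotated in this figure come from `model.score`, which for scikit-learn regressors is the coefficient of determination, R² = 1 - SS_res / SS_tot. A small sketch of that formula shows why an R² validation score can be negative: any prediction worse than simply predicting the mean drops below zero. The helper name `r_squared` is hypothetical.

```python
import numpy as np

def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)          # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # spread around the mean
    return 1.0 - ss_res / ss_tot

y = np.array([1.0, 2.0, 3.0, 4.0])
print(r_squared(y, y))                       # 1.0: perfect predictions
print(r_squared(y, np.full(4, y.mean())))    # 0.0: always predicting the mean
print(r_squared(y, y[::-1]))                 # -3.0: worse than predicting the mean
```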
Validation Curve
In [28]: x = np.linspace(0, 1, 1000)
y1 = -(x - 0.5) ** 2
y2 = y1 - 0.33 + np.exp(x - 1)
fig, ax = plt.subplots()
ax.plot(x, y2, lw=10, alpha=0.5, color='blue')
ax.plot(x, y1, lw=10, alpha=0.5, color='red')
ax.text(0.15, 0.2, "training score", rotation=45, size=16, color='blue')
ax.text(0.2, -0.05, "validation score", rotation=20, size=16, color='red')
ax.text(0.02, 0.1, r'$\longleftarrow$ High Bias', size=18, rotation=90, va='center')
ax.text(0.98, 0.1, r'$\longleftarrow$ High Variance $\longrightarrow$', size=18, rotation=90, ha='right', va='center')
ax.text(0.48, -0.12, 'Best$\\longrightarrow$\nModel', size=18, rotation=90, va='center')
ax.set_xlim(0, 1)
ax.set_ylim(-0.3, 0.5)
ax.set_xlabel(r'model complexity $\longrightarrow$', size=14)
ax.set_ylabel(r'model score $\longrightarrow$', size=14)
ax.xaxis.set_major_formatter(plt.NullFormatter())
ax.yaxis.set_major_formatter(plt.NullFormatter())
ax.set_title("Validation Curve Schematic", size=16)
fig.savefig('figures/05.03-validation-curve.png')
Learning Curve
In [29]: N = np.linspace(0, 1, 1000)
y1 = 0.75 + 0.2 * np.exp(-4 * N)
y2 = 0.7 - 0.6 * np.exp(-4 * N)
fig, ax = plt.subplots()
ax.plot(N, y1, lw=10, alpha=0.5, color='blue')
ax.plot(N, y2, lw=10, alpha=0.5, color='red')
ax.text(0.2, 0.88, "training score", rotation=-10, size=16, color='blue')
ax.text(0.2, 0.5, "validation score", rotation=30, size=16, color='red')
ax.text(0.98, 0.45, r'Good Fit $\longrightarrow$', size=18, rotation=90, ha='right', va='center')
ax.text(0.02, 0.57, r'$\longleftarrow$ High Variance $\longrightarrow$', size=18, rotation=90, va='center')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_xlabel(r'training set size $\longrightarrow$', size=14)
ax.set_ylabel(r'model score $\longrightarrow$', size=14)
ax.xaxis.set_major_formatter(plt.NullFormatter())
ax.yaxis.set_major_formatter(plt.NullFormatter())
ax.set_title("Learning Curve Schematic", size=16)
fig.savefig('figures/05.03-learning-curve.png')
Gaussian Naive Bayes
Gaussian Naive Bayes Example
Figure context
In [30]: from sklearn.datasets import make_blobs
X, y = make_blobs(100, 2, centers=2, random_state=2, cluster_std=1.5)
fig, ax = plt.subplots()
ax.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='RdBu')
ax.set_title('Naive Bayes Model', size=14)
xlim = (-8, 8)
ylim = (-15, 5)
xg = np.linspace(xlim[0], xlim[1], 60)
yg = np.linspace(ylim[0], ylim[1], 40)
xx, yy = np.meshgrid(xg, yg)
Xgrid = np.vstack([xx.ravel(), yy.ravel()]).T
for label, color in enumerate(['red', 'blue']):
    mask = (y == label)
    mu, std = X[mask].mean(0), X[mask].std(0)
    P = np.exp(-0.5 * (Xgrid - mu) ** 2 / std ** 2).prod(1)
    Pm = np.ma.masked_array(P, P < 0.03)
    ax.pcolorfast(xg, yg, Pm.reshape(xx.shape), alpha=0.5,
                  cmap=color.title() + 's')
    ax.contour(xx, yy, P.reshape(xx.shape),
               levels=[0.01, 0.1, 0.5, 0.9],
               colors=color, alpha=0.2)
ax.set(xlim=xlim, ylim=ylim)
fig.savefig('figures/05.05-gaussian-NB.png')
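For each class, the figure draws an axis-aligned Gaussian built from per-feature means and standard deviations. Turning those Gaussians into a classifier only needs class priors and a comparison of log-posteriors. Below is a minimal NumPy sketch on made-up two-blob data; all names are hypothetical. This is not scikit-learn's `GaussianNB` code. It also assumes every feature has nonzero spread within every class, whereas `GaussianNB` adds a small `var_smoothing` term to guard against zero variance.

```python
import numpy as np

def fit_gaussian_nb(X, y):
    """Per-class feature means, standard deviations, and class priors."""
    classes = np.unique(y)
    mu = np.array([X[y == c].mean(axis=0) for c in classes])   # (n_classes, n_features)
    std = np.array([X[y == c].std(axis=0) for c in classes])   # (n_classes, n_features)
    prior = np.array([np.mean(y == c) for c in classes])       # (n_classes,)
    return classes, mu, std, prior

def predict_gaussian_nb(X, classes, mu, std, prior):
    """Pick the class with the largest log prior + summed per-feature log-likelihood."""
    z = (X[:, None, :] - mu) / std                    # (n_samples, n_classes, n_features)
    log_lik = -0.5 * z ** 2 - np.log(std * np.sqrt(2 * np.pi))
    # "Naive" assumption: features are independent given the class,
    # so their log-likelihoods simply add up.
    log_post = np.log(prior) + log_lik.sum(axis=2)    # (n_samples, n_classes)
    return classes[np.argmax(log_post, axis=1)]

rng = np.random.RandomState(0)
X = np.vstack([rng.randn(50, 2) + [0, 4], rng.randn(50, 2) + [0, -4]])
y = np.repeat([0, 1], 50)
params = fit_gaussian_nb(X, y)
pred = predict_gaussian_nb(np.array([[0.0, 4.0], [0.0, -4.0]]), *params)
print(pred)  # [0 1]
```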
Linear Regression
Gaussian Basis Functions
Figure context
In [31]: from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LinearRegression
from sklearn.base import BaseEstimator, TransformerMixin
class GaussianFeatures(BaseEstimator, TransformerMixin):
    """Uniformly-spaced Gaussian Features for 1D input"""
    def __init__(self, N, width_factor=2.0):
        self.N = N
        self.width_factor = width_factor
    @staticmethod
    def _gauss_basis(x, y, width, axis=None):
        arg = (x - y) / width
        return np.exp(-0.5 * np.sum(arg ** 2, axis))
    def fit(self, X, y=None):
        # create N centers spread along the data range
        self.centers_ = np.linspace(X.min(), X.max(), self.N)
        self.width_ = self.width_factor * (self.centers_[1] - self.centers_[0])
        return self
    def transform(self, X):
        return self._gauss_basis(X[:, :, np.newaxis], self.centers_,
                                 self.width_, axis=1)
rng = np.random.RandomState(1)
x = 10 * rng.rand(50)
y = np.sin(x) + 0.1 * rng.randn(50)
xfit = np.linspace(0, 10, 1000)
gauss_model = make_pipeline(GaussianFeatures(10, 1.0),
LinearRegression())
gauss_model.fit(x[:, np.newaxis], y)
yfit = gauss_model.predict(xfit[:, np.newaxis])
gf = gauss_model.named_steps['gaussianfeatures']
lm = gauss_model.named_steps['linearregression']
fig, ax = plt.subplots()
for i in range(10):
    selector = np.zeros(10)
    selector[i] = 1
    Xfit = gf.transform(xfit[:, None]) * selector
    yfit = lm.predict(Xfit)
    ax.fill_between(xfit, yfit.min(), yfit, color='gray', alpha=0.2)
ax.scatter(x, y)
ax.plot(xfit, gauss_model.predict(xfit[:, np.newaxis]))
ax.set_xlim(0, 10)
ax.set_ylim(yfit.min(), 1.5)
fig.savefig('figures/05.06-gaussian-basis.png')
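Inside the pipeline, `GaussianFeatures` maps each scalar input to ten Gaussian bumps, and `LinearRegression` fits one weight per bump. The sketch below reproduces that by hand with NumPy only. It uses the same center and width rule as `GaussianFeatures(10, 1.0)`, with `np.linalg.lstsq` standing in for `LinearRegression`. The names `gaussian_features`, `Phi` and `weights` are hypothetical. The residual RMS should land near the 0.1 noise level used to generate `y`.

```python
import numpy as np

def gaussian_features(x, centers, width):
    """Map 1D inputs to Gaussian bumps, like GaussianFeatures.transform above."""
    # (n_samples, 1) minus (n_centers,) broadcasts to (n_samples, n_centers).
    return np.exp(-0.5 * ((x[:, None] - centers) / width) ** 2)

rng = np.random.RandomState(1)
x = 10 * rng.rand(50)
y = np.sin(x) + 0.1 * rng.randn(50)

# Same rule as GaussianFeatures(10, 1.0): 10 centers over the data range,
# width = width_factor * spacing between neighbouring centers.
centers = np.linspace(x.min(), x.max(), 10)
width = 1.0 * (centers[1] - centers[0])

# Least-squares fit on the basis, plus an intercept column.
Phi = np.hstack([gaussian_features(x, centers, width), np.ones((len(x), 1))])
weights, *_ = np.linalg.lstsq(Phi, y, rcond=None)

rms = np.sqrt(np.mean((Phi @ weights - y) ** 2))
print(rms)  # on the order of the 0.1 noise level
```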
Random Forests
Helper Code
The following will create a module helpers_05_08.py, which contains some tools used in In-Depth: Decision Trees and Random Forests.
In [32]: %%file helpers_05_08.py
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from ipywidgets import interact
def visualize_tree(estimator, X, y, boundaries=True,
                   xlim=None, ylim=None, ax=None):
    ax = ax or plt.gca()
    # Plot the training points
    ax.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap='viridis',
               clim=(y.min(), y.max()), zorder=3)
    ax.axis('tight')
    ax.axis('off')
    if xlim is None:
        xlim = ax.get_xlim()
    if ylim is None:
        ylim = ax.get_ylim()
    # Fit the estimator
    estimator.fit(X, y)
    xx, yy = np.meshgrid(np.linspace(*xlim, num=200),
                         np.linspace(*ylim, num=200))
    Z = estimator.predict(np.c_[xx.ravel(), yy.ravel()])
    # Put the result into a color plot
    n_classes = len(np.unique(y))
    Z = Z.reshape(xx.shape)
    # contourf does not use 'clim' (it raised a UserWarning); vmin/vmax match the scatter colors
    contours = ax.contourf(xx, yy, Z, alpha=0.3,
                           levels=np.arange(n_classes + 1) - 0.5,
                           cmap='viridis', vmin=y.min(), vmax=y.max(),
                           zorder=1)
    ax.set(xlim=xlim, ylim=ylim)
    # Plot the decision boundaries
    def plot_boundaries(i, xlim, ylim):
        if i >= 0:
            tree = estimator.tree_
            if tree.feature[i] == 0:
                ax.plot([tree.threshold[i], tree.threshold[i]], ylim, '-k', zorder=2)
                plot_boundaries(tree.children_left[i],
                                [xlim[0], tree.threshold[i]], ylim)
                plot_boundaries(tree.children_right[i],
                                [tree.threshold[i], xlim[1]], ylim)
            elif tree.feature[i] == 1:
                ax.plot(xlim, [tree.threshold[i], tree.threshold[i]], '-k', zorder=2)
                plot_boundaries(tree.children_left[i], xlim,
                                [ylim[0], tree.threshold[i]])
                plot_boundaries(tree.children_right[i], xlim,
                                [tree.threshold[i], ylim[1]])
    if boundaries:
        plot_boundaries(0, xlim, ylim)
def plot_tree_interactive(X, y):
    def interactive_tree(depth=5):
        clf = DecisionTreeClassifier(max_depth=depth, random_state=0)
        visualize_tree(clf, X, y)
    # A (min, max) tuple of ints gives an integer slider
    return interact(interactive_tree, depth=(1, 5))
def randomized_tree_interactive(X, y):
    N = int(0.75 * X.shape[0])
    xlim = (X[:, 0].min(), X[:, 0].max())
    ylim = (X[:, 1].min(), X[:, 1].max())
    def fit_randomized_tree(random_state=0):
        clf = DecisionTreeClassifier(max_depth=15)
        i = np.arange(len(y))
        rng = np.random.RandomState(random_state)
        rng.shuffle(i)
        visualize_tree(clf, X[i[:N]], y[i[:N]], boundaries=False,
                       xlim=xlim, ylim=ylim)
    interact(fit_randomized_tree, random_state=(0, 100));
Overwriting helpers_05_08.py
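`plot_boundaries` in `helpers_05_08.py` walks the fitted tree's parallel arrays (`tree_.feature`, `tree_.threshold`, `tree_.children_left`, `tree_.children_right`) and shrinks the plotting box at each split. The sketch below applies the same recursion to a small hand-built tree and returns the line segments instead of drawing them, so the logic can be checked without scikit-learn or Matplotlib. The tree and the name `collect_splits` are hypothetical. In this sketch, a child index of -1 means "no child", and any feature other than 0 or 1 is treated as a leaf.

```python
def collect_splits(feature, threshold, children_left, children_right,
                   xlim=(0.0, 1.0), ylim=(0.0, 1.0)):
    """Return the boundary segments as ((x0, x1), (y0, y1)) tuples."""
    segments = []

    def walk(node, xlim, ylim):
        if node < 0:                      # -1: no child here
            return
        t = threshold[node]
        if feature[node] == 0:            # split on feature 0 -> vertical line
            segments.append(((t, t), tuple(ylim)))
            walk(children_left[node], (xlim[0], t), ylim)
            walk(children_right[node], (t, xlim[1]), ylim)
        elif feature[node] == 1:          # split on feature 1 -> horizontal line
            segments.append((tuple(xlim), (t, t)))
            walk(children_left[node], xlim, (ylim[0], t))
            walk(children_right[node], xlim, (t, ylim[1]))
        # any other feature value: a leaf, nothing to draw

    walk(0, xlim, ylim)
    return segments

# Hypothetical tree: root splits x at 0.5; its right child splits y at 0.3.
feature = [0, -2, 1, -2, -2]
threshold = [0.5, -2.0, 0.3, -2.0, -2.0]
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
segments = collect_splits(feature, threshold, children_left, children_right)
print(segments)  # [((0.5, 0.5), (0.0, 1.0)), ((0.5, 1.0), (0.3, 0.3))]
```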
Decision Tree Example
In [33]: fig = plt.figure(figsize=(10, 4))
ax = fig.add_axes([0, 0, 0.8, 1], frameon=False, xticks=[], yticks=[])
ax.set_title('Example Decision Tree: Animal Classification', size=24)
def text(ax, x, y, t, size=20, **kwargs):
    ax.text(x, y, t,
            ha='center', va='center', size=size,
            bbox=dict(boxstyle='round', ec='k', fc='w'), **kwargs)
text(ax, 0.5, 0.9, "How big is\nthe animal?", 20)
text(ax, 0.3, 0.6, "Does the animal\nhave horns?", 18)
text(ax, 0.7, 0.6, "Does the animal\nhave two legs?", 18)
text(ax, 0.12, 0.3, "Are the horns\nlonger than 10cm?", 14)
text(ax, 0.38, 0.3, "Is the animal\nwearing a collar?", 14)
text(ax, 0.62, 0.3, "Does the animal\nhave wings?", 14)
text(ax, 0.88, 0.3, "Does the animal\nhave a tail?", 14)
text(ax, 0.4, 0.75, "> 1m", 12, alpha=0.4)
text(ax, 0.6, 0.75, "< 1m", 12, alpha=0.4)
text(ax, 0.21, 0.45, "yes", 12, alpha=0.4)
text(ax, 0.34, 0.45, "no", 12, alpha=0.4)
text(ax, 0.66, 0.45, "yes", 12, alpha=0.4)
text(ax, 0.79, 0.45, "no", 12, alpha=0.4)
ax.plot([0.3, 0.5, 0.7], [0.6, 0.9, 0.6], '-k')
ax.plot([0.12, 0.3, 0.38], [0.3, 0.6, 0.3], '-k')
ax.plot([0.62, 0.7, 0.88], [0.3, 0.6, 0.3], '-k')
ax.plot([0.0, 0.12, 0.20], [0.0, 0.3, 0.0], '--k')
ax.plot([0.28, 0.38, 0.48], [0.0, 0.3, 0.0], '--k')
ax.plot([0.52, 0.62, 0.72], [0.0, 0.3, 0.0], '--k')
ax.plot([0.8, 0.88, 1.0], [0.0, 0.3, 0.0], '--k')
ax.axis([0, 1, 0, 1])
fig.savefig('figures/05.08-decision-tree.png')
Decision Tree Levels
In [34]: from helpers_05_08 import visualize_tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_blobs
fig, ax = plt.subplots(1, 4, figsize=(16, 3))
fig.subplots_adjust(left=0.02, right=0.98, wspace=0.1)
X, y = make_blobs(n_samples=300, centers=4,
random_state=0, cluster_std=1.0)
for axi, depth in zip(ax, range(1, 5)):
    model = DecisionTreeClassifier(max_depth=depth)
    visualize_tree(model, X, y, ax=axi)
    axi.set_title('depth = {0}'.format(depth))
fig.savefig('figures/05.08-decision-tree-levels.png')
Decision Tree Overfitting
In [35]: model = DecisionTreeClassifier()
fig, ax = plt.subplots(1, 2, figsize=(16, 6))
fig.subplots_adjust(left=0.0625, right=0.95, wspace=0.1)
visualize_tree(model, X[::2], y[::2], boundaries=False, ax=ax[0])
visualize_tree(model, X[1::2], y[1::2], boundaries=False, ax=ax[1])
fig.savefig('figures/05.08-decision-tree-overfitting.png')
Principal Component Analysis
Principal Components Rotation
In [36]: from sklearn.decomposition import PCA
In [37]: def draw_vector(v0, v1, ax=None):
    ax = ax or plt.gca()
    arrowprops = dict(arrowstyle='->',
                      linewidth=2,
                      shrinkA=0, shrinkB=0)
    ax.annotate('', v1, v0, arrowprops=arrowprops)
In [38]: rng = np.random.RandomState(1)
X = np.dot(rng.rand(2, 2), rng.randn(2, 200)).T
pca = PCA(n_components=2, whiten=True)
pca.fit(X)
fig, ax = plt.subplots(1, 2, figsize=(16, 6))
fig.subplots_adjust(left=0.0625, right=0.95, wspace=0.1)
# Plot the data
ax[0].scatter(X[:, 0], X[:, 1], alpha=0.2)
for length, vector in zip(pca.explained_variance_, pca.components_):
    v = vector * 3 * np.sqrt(length)
    draw_vector(pca.mean_, pca.mean_ + v, ax=ax[0])
ax[0].axis('equal');
ax[0].set(xlabel='x', ylabel='y', title='input')
# Plot the principal components
X_pca = pca.transform(X)
ax[1].scatter(X_pca[:, 0], X_pca[:, 1], alpha=0.2)
draw_vector([0, 0], [0, 3], ax=ax[1])
draw_vector([0, 0], [3, 0], ax=ax[1])
ax[1].axis('equal')
ax[1].set(xlabel='component 1', ylabel='component 2',
title='principal components',
xlim=(-5, 5), ylim=(-3, 3.1))
fig.savefig('figures/05.09-PCA-rotation.png')
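`PCA(whiten=True)` rotates the data onto its principal axes and then rescales each axis to unit variance. Below is a minimal NumPy sketch of that computation via an eigendecomposition of the covariance matrix, reusing the data recipe above. The transformed coordinates should agree with `pca.transform(X)` up to the sign of each component, and their covariance should be the identity. Variable names are illustrative.

```python
import numpy as np

rng = np.random.RandomState(1)
X = np.dot(rng.rand(2, 2), rng.randn(2, 200)).T   # same recipe as the cell above

# Center the data and eigendecompose its covariance matrix.
mean = X.mean(axis=0)
Xc = X - mean
cov = np.cov(Xc, rowvar=False)              # n - 1 denominator
eigvals, eigvecs = np.linalg.eigh(cov)      # eigh returns ascending eigenvalues
order = np.argsort(eigvals)[::-1]           # largest variance first
eigvals = eigvals[order]
components = eigvecs[:, order].T            # one principal axis per row

# Rotate onto the principal axes, then scale each axis to unit variance.
X_rot = Xc @ components.T
X_white = X_rot / np.sqrt(eigvals)

print(np.cov(X_white, rowvar=False).round(6))  # approximately the 2 x 2 identity
```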
Digits Pixel Components
In [39]: def plot_pca_components(x, coefficients=None, mean=0, components=None,
                        imshape=(8, 8), n_components=8, fontsize=12,
                        show_mean=True):
    if coefficients is None:
        coefficients = x
    if components is None:
        components = np.eye(len(coefficients), len(x))
    mean = np.zeros_like(x) + mean
    fig = plt.figure(figsize=(1.2 * (5 + n_components), 1.2 * 2))
    g = plt.GridSpec(2, 4 + bool(show_mean) + n_components, hspace=0.3)
    def show(i, j, x, title=None):
        ax = fig.add_subplot(g[i, j], xticks=[], yticks=[])
        ax.imshow(x.reshape(imshape), interpolation='nearest')
        if title:
            ax.set_title(title, fontsize=fontsize)
    show(slice(2), slice(2), x, "True")
    approx = mean.copy()
    counter = 2
    if show_mean:
        show(0, 2, np.zeros_like(x) + mean, r'$\mu$')
        show(1, 2, approx, r'$1 \cdot \mu$')
        counter += 1
    for i in range(n_components):
        approx = approx + coefficients[i] * components[i]
        show(0, i + counter, components[i], r'$c_{0}$'.format(i + 1))
        show(1, i + counter, approx,
             r"${0:.2f} \cdot c_{1}$".format(coefficients[i], i + 1))
        if show_mean or i > 0:
            plt.gca().text(0, 1.05, '$+$', ha='right', va='bottom',
                           transform=plt.gca().transAxes, fontsize=fontsize)
    show(slice(2), slice(-2, None), approx, "Approx")
    return fig
In [40]: from sklearn.datasets import load_digits
digits = load_digits()
sns.set_style('white')
fig = plot_pca_components(digits.data[10],
show_mean=False)
fig.savefig('figures/05.09-digits-pixel-components.png')
Digits PCA Components
In [41]: pca = PCA(n_components=8)
Xproj = pca.fit_transform(digits.data)
sns.set_style('white')
fig = plot_pca_components(digits.data[10], Xproj[10],
pca.mean_, pca.components_)
fig.savefig('figures/05.09-digits-pca-components.png')
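`plot_pca_components` builds each approximation as the mean plus a running sum of coefficient-weighted components. The sketch below does the same arithmetic on stand-in data: a hypothetical 100 x 6 random array rather than the digits. It takes orthonormal components from an SVD of the centered data. Because the components are orthonormal, the reconstruction error can only shrink as components are added, and it vanishes once all of them are used.

```python
import numpy as np

rng = np.random.RandomState(0)
X = rng.rand(100, 6)                  # hypothetical stand-in data: 100 samples, 6 "pixels"

mean = X.mean(axis=0)
# Rows of Vt are orthonormal principal directions of the centered data.
_, _, Vt = np.linalg.svd(X - mean, full_matrices=False)

x = X[10]
coeffs = Vt @ (x - mean)              # coefficient of x along each component

def reconstruct(n):
    """Mean plus the first n coefficient-weighted components."""
    return mean + coeffs[:n] @ Vt[:n]

errors = [np.linalg.norm(x - reconstruct(n)) for n in range(len(coeffs) + 1)]
print(np.round(errors, 3))            # decreases with n; about 0 once all 6 are used
```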
Manifold Learning
LLE vs MDS Linkages
In [42]: def make_hello(N=1000, rseed=42):
    # Make a plot of "HELLO" text and save it as hello.png
    fig, ax = plt.subplots(figsize=(4, 1))
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
    ax.axis('off')
    ax.text(0.5, 0.4, 'HELLO', va='center', ha='center', weight='bold', size=85)
    fig.savefig('hello.png')
    plt.close(fig)
    # Open this PNG and draw random points from it
    from matplotlib.image import imread
    data = imread('hello.png')[::-1, :, 0].T
    rng = np.random.RandomState(rseed)
    X = rng.rand(4 * N, 2)
    i, j = (X * data.shape).astype(int).T
    mask = (data[i, j] < 1)
    X = X[mask]
    X[:, 0] *= (data.shape[0] / data.shape[1])
    X = X[:N]
    return X[np.argsort(X[:, 0])]
In [43]: def make_hello_s_curve(X):
    t = (X[:, 0] - 2) * 0.75 * np.pi
    x = np.sin(t)
    y = X[:, 1]
    z = np.sign(t) * (np.cos(t) - 1)
    return np.vstack((x, y, z)).T
X = make_hello(1000)
XS = make_hello_s_curve(X)
colorize = dict(c=X[:, 0], cmap=plt.get_cmap('rainbow', 5))
In [44]: from mpl_toolkits.mplot3d.art3d import Line3DCollection
from sklearn.neighbors import NearestNeighbors
# Construct the lines for MDS
rng = np.random.RandomState(42)
ind = rng.permutation(len(X))
lines_MDS = [(XS[i], XS[j]) for i in ind[:100] for j in ind[100:200]]
# Construct the lines for LLE
nbrs = NearestNeighbors(n_neighbors=100).fit(XS).kneighbors(XS[ind[:100]])[1]
lines_LLE = [(XS[ind[i]], XS[j]) for i in range(100) for j in nbrs[i]]
titles = ['MDS Linkages', 'LLE Linkages (100 NN)']
# Plot the results
fig, ax = plt.subplots(1, 2, figsize=(16, 6),
                       subplot_kw=dict(projection='3d', facecolor='none'))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0, wspace=0)
for axi, title, lines in zip(ax, titles, [lines_MDS, lines_LLE]):
    axi.scatter3D(XS[:, 0], XS[:, 1], XS[:, 2], **colorize);
    axi.add_collection(Line3DCollection(lines, lw=1, color='black',
                                        alpha=0.05))
    axi.view_init(elev=10, azim=-80)
    axi.set_title(title, size=18)
fig.savefig('figures/05.10-LLE-vs-MDS.png')
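In this figure, the LLE panel links each of 100 sampled points to its 100 nearest neighbours. The MDS panel instead links every sampled point to 100 other random points, regardless of distance. The neighbour search can be sketched with brute-force NumPy distances. As with `NearestNeighbors.kneighbors` called on points already in the fitted set, a query point comes back as its own nearest neighbour. `knn_links` and the toy points are hypothetical.

```python
import numpy as np

def knn_links(points, query_idx, k):
    """Pairs (i, j) joining each query point i to its k nearest neighbours j."""
    query = points[query_idx]                              # (n_query, d)
    diffs = query[:, None, :] - points[None, :, :]         # (n_query, n_points, d)
    dist = np.sqrt((diffs ** 2).sum(axis=-1))              # (n_query, n_points)
    nbrs = np.argsort(dist, axis=1, kind='stable')[:, :k]  # k closest per query
    return [(int(i), int(j)) for row, i in enumerate(query_idx) for j in nbrs[row]]

points = np.array([[0.0], [1.0], [3.0], [10.0]])
links = knn_links(points, [0, 3], k=2)
print(links)  # [(0, 0), (0, 1), (3, 3), (3, 2)]
```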
K-Means
Expectation-Maximization
Figure context
The following figure shows a visual depiction of the Expectation-Maximization approach to K-means:
In [45]: from sklearn.datasets import make_blobs
from sklearn.metrics import pairwise_distances_argmin
X, y_true = make_blobs(n_samples=300, centers=4,
                       cluster_std=0.60, random_state=0)
rng = np.random.RandomState(42)
centers = [0, 4] + rng.randn(4, 2)
def draw_points(ax, c, factor=1):
    ax.scatter(X[:, 0], X[:, 1], c=c, cmap='viridis',
               s=50 * factor, alpha=0.3)
def draw_centers(ax, centers, factor=1, alpha=1.0):
    ax.scatter(centers[:, 0], centers[:, 1],
               c=np.arange(4), cmap='viridis', s=200 * factor,
               alpha=alpha)
    ax.scatter(centers[:, 0], centers[:, 1],
               c='black', s=50 * factor, alpha=alpha)
def make_ax(fig, gs):
    ax = fig.add_subplot(gs)
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    return ax
fig = plt.figure(figsize=(15, 4))
gs = plt.GridSpec(4, 15, left=0.02, right=0.98, bottom=0.05, top=0.95, wspace=0.2, hspace=0.2)
ax0 = make_ax(fig, gs[:4, :4])
ax0.text(0.98, 0.98, "Random Initialization", transform=ax0.transAxes,
         ha='right', va='top', size=16)
draw_points(ax0, 'gray', factor=2)
draw_centers(ax0, centers, factor=2)
for i in range(3):
    ax1 = make_ax(fig, gs[:2, 4 + 2 * i:6 + 2 * i])
    ax2 = make_ax(fig, gs[2:, 5 + 2 * i:7 + 2 * i])
    # E-step: assign each point to its nearest center
    y_pred = pairwise_distances_argmin(X, centers)
    draw_points(ax1, y_pred)
    draw_centers(ax1, centers)
    # M-step: move each center to the mean of its assigned points
    new_centers = np.array([X[y_pred == k].mean(0) for k in range(4)])
    draw_points(ax2, y_pred)
    draw_centers(ax2, centers, alpha=0.3)
    draw_centers(ax2, new_centers)
    for k in range(4):
        ax2.annotate('', new_centers[k], centers[k],
                     arrowprops=dict(arrowstyle='->', linewidth=1))
    # Finish the iteration
    centers = new_centers
    ax1.text(0.95, 0.95, "E-Step", transform=ax1.transAxes, ha='right', va='top', size=14)
    ax2.text(0.95, 0.95, "M-Step", transform=ax2.transAxes, ha='right', va='top', size=14)
# Final E-step
y_pred = pairwise_distances_argmin(X, centers)
axf = make_ax(fig, gs[:4, -4:])
draw_points(axf, y_pred, factor=2)
draw_centers(axf, centers, factor=2)
axf.text(0.98, 0.98, "Final Clustering", transform=axf.transAxes,
         ha='right', va='top', size=16)
fig.savefig('figures/05.11-expectation-maximization.png')
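The figure alternates an E-step (assign each point to its nearest center with pairwise_distances_argmin) and an M-step (move each center to the mean of its points). Stripped of plotting, the same loop run to convergence might look like the sketch below. Two choices here are assumptions rather than the figure's code: the centers start at randomly chosen data points, and the loop stops once the assignments no longer change. The name kmeans_em is hypothetical.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import pairwise_distances_argmin


def kmeans_em(X, n_clusters, rseed=2, max_iter=100):
    """Lloyd's K-means: alternate E-step and M-step until assignments stop changing."""
    rng = np.random.RandomState(rseed)
    # Assumption: start from randomly chosen data points
    centers = X[rng.permutation(len(X))[:n_clusters]]
    labels = None
    for _ in range(max_iter):
        # E-step: assign every point to its nearest center
        new_labels = pairwise_distances_argmin(X, centers)
        if labels is not None and np.array_equal(new_labels, labels):
            break  # assignments are stable, so the centers no longer move
        labels = new_labels
        # M-step: move each center to the mean of its points (keep it if empty)
        centers = np.array([X[labels == k].mean(0) if np.any(labels == k)
                            else centers[k] for k in range(n_clusters)])
    return centers, labels


X, y_true = make_blobs(n_samples=300, centers=4,
                       cluster_std=0.60, random_state=0)
centers, labels = kmeans_em(X, 4)
print(centers.round(2))
```

Stopping on unchanged assignments is a convenient criterion: once the labels repeat, the M-step would reproduce exactly the same centers, so further iterations cannot move anything. As with any K-means run, the result can depend on the starting centers.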
Interactive K-Means
The following script uses IPython's interactive widgets to demonstrate the K-means algorithm step by step. Run it within the IPython notebook to explore the expectation-maximization algorithm for computing K-means.
In [46]: %matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set()  # apply the seaborn plot style
import numpy as np
from ipywidgets import interact
from sklearn.metrics import pairwise_distances_argmin
from sklearn.datasets import make_blobs
def plot_kmeans_interactive(min_clusters=1, max_clusters=6):
    X, y = make_blobs(n_samples=300, centers=4,
                      random_state=0, cluster_std=0.60)
    def plot_points(X, labels, n_clusters):
        plt.scatter(X[:, 0], X[:, 1], c=labels, s=50, cmap='viridis',
                    vmin=0, vmax=n_clusters - 1)
    def plot_centers(centers):
        plt.scatter(centers[:, 0], centers[:, 1], marker='o',
                    c=np.arange(centers.shape[0]),
                    s=200, cmap='viridis')
        plt.scatter(centers[:, 0], centers[:, 1], marker='o',
                    c='black', s=50)
    def _kmeans_step(frame=0, n_clusters=4):
        rng = np.random.RandomState(2)
        labels = np.zeros(X.shape[0])
        centers = rng.randn(n_clusters, 2)
        nsteps = frame // 3
        for i in range(nsteps + 1):
            old_centers = centers
            if i < nsteps or frame % 3 > 0:
                labels = pairwise_distances_argmin(X, centers)
            if i < nsteps or frame % 3 > 1:
                centers = np.array([X[labels == j].mean(0)
                                    for j in range(n_clusters)])
                # An empty cluster yields NaN; keep its previous center
                nans = np.isnan(centers)
                centers[nans] = old_centers[nans]
        # Plot the data and the cluster centers
        plot_points(X, labels, n_clusters)
        plot_centers(old_centers)
        # On the third frame of each iteration, draw the updated centers
        if frame % 3 == 2:
            for i in range(n_clusters):
                plt.annotate('', xy=centers[i], xytext=old_centers[i],
                             arrowprops=dict(arrowstyle='->', linewidth=1))
            plot_centers(centers)
        plt.xlim(-4, 4)
        plt.ylim(-2, 10)
        if frame % 3 == 1:
            plt.text(3.8, 9.5, "1. Reassign points to nearest centroid",
                     ha='right', va='top', size=14)
        elif frame % 3 == 2:
            plt.text(3.8, 9.5, "2. Update centroids to cluster means",
                     ha='right', va='top', size=14)
    # (min, max) tuples make interact() build integer sliders
    return interact(_kmeans_step, frame=(0, 50),
                    n_clusters=(min_clusters, max_clusters))
plot_kmeans_interactive();
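Inside _kmeans_step, the frame slider encodes both the iteration and the phase: frame // 3 counts completed iterations, and frame % 3 selects whether the current iteration shows nothing new yet (0), the E-step reassignment (1), or the M-step update (2). The sketch below pulls that bookkeeping out into a plain function so it can be inspected without ipywidgets. The name kmeans_frame_state is hypothetical, and empty clusters keep their previous center, which is what the NaN replacement in the script achieves.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import pairwise_distances_argmin


def kmeans_frame_state(X, frame, n_clusters=4, rseed=2):
    """Return (labels, old_centers, centers) for one frame of the animation.

    Frames come in groups of three per iteration: frame % 3 == 0 shows the
    current centers, 1 adds the E-step reassignment, 2 adds the M-step update.
    """
    rng = np.random.RandomState(rseed)
    labels = np.zeros(X.shape[0], dtype=int)
    centers = rng.randn(n_clusters, 2)
    nsteps = frame // 3
    for i in range(nsteps + 1):
        old_centers = centers
        # E-step: always for completed iterations, and for the current one
        # once the frame has reached phase 1
        if i < nsteps or frame % 3 > 0:
            labels = pairwise_distances_argmin(X, centers)
        # M-step: always for completed iterations, and for the current one
        # only at phase 2; an empty cluster keeps its previous center
        if i < nsteps or frame % 3 > 1:
            centers = np.array([X[labels == j].mean(0) if np.any(labels == j)
                                else old_centers[j] for j in range(n_clusters)])
    return labels, old_centers, centers


X, _ = make_blobs(n_samples=300, centers=4, random_state=0, cluster_std=0.60)
for frame in range(4):
    labels, old_centers, centers = kmeans_frame_state(X, frame)
    print(frame, np.bincount(labels, minlength=4), np.allclose(old_centers, centers))
```

Printing frames 0 through 3 shows the pattern: frame 0 puts every point in cluster 0, frame 1 reassigns points while the centers stay put, frame 2 moves the centers, and frame 3 starts the next iteration from those moved centers.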
Gaussian Mixture Models
Covariance Type
Figure context
In [47]: from sklearn.mixture import GaussianMixture as GMM
from matplotlib.patches import Ellipse
def draw_ellipse(position, covariance, ax=None, **kwargs):
    """Draw an ellipse with a given position and covariance"""
    ax = ax or plt.gca()
    # Convert the covariance to principal axes
    if covariance.shape == (2, 2):
        U, s, Vt = np.linalg.svd(covariance)
        angle = np.degrees(np.arctan2(U[1, 0], U[0, 0]))
        width, height = 2 * np.sqrt(s)
    else:
        angle = 0
        width, height = 2 * np.sqrt(covariance)
    # Draw the ellipse at 1, 2, and 3 standard deviations
    for nsig in range(1, 4):
        ax.add_patch(Ellipse(position, nsig * width, nsig * height,
                             angle=angle, **kwargs))
fig, ax = plt.subplots(1, 3, figsize=(14, 4))
fig.subplots_adjust(wspace=0.05)
rng = np.random.RandomState(5)
X = np.dot(rng.randn(500, 2), rng.randn(2, 2))
for i, cov_type in enumerate(['diag', 'spherical', 'full']):
    model = GMM(1, covariance_type=cov_type).fit(X)
    ax[i].axis('equal')
    ax[i].scatter(X[:, 0], X[:, 1], alpha=0.5)
    ax[i].set_xlim(-3, 3)
    ax[i].set_title('covariance_type="{0}"'.format(cov_type),
                    size=14, family='monospace')
    # covariances_ has a different shape for each covariance_type
    if model.covariance_type == 'spherical':
        cov = np.eye(2) * model.covariances_
    else:
        cov = model.covariances_[0]
    draw_ellipse(model.means_[0], cov, ax[i], alpha=0.2)
    ax[i].xaxis.set_major_formatter(plt.NullFormatter())
    ax[i].yaxis.set_major_formatter(plt.NullFormatter())
fig.savefig('figures/05.12-covariance-type.png')
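The spherical branch in the loop above exists because GaussianMixture stores covariances_ with a different shape for each covariance_type: (n_components,) for "spherical", (n_components, n_features) for "diag", (n_components, n_features, n_features) for "full", and (n_features, n_features) for "tied". A small helper that expands any of these into full matrices, so that a single code path such as the 2x2 branch of draw_ellipse could handle every case, might look like the following. full_covariances is a hypothetical name, not a scikit-learn function.

```python
import numpy as np
from sklearn.mixture import GaussianMixture


def full_covariances(model):
    """Expand a fitted GaussianMixture's covariances_ to shape (n_components, d, d)."""
    cov = model.covariances_
    n_components, d = model.means_.shape
    if model.covariance_type == 'full':       # already (n_components, d, d)
        return cov
    if model.covariance_type == 'tied':       # (d, d), shared by all components
        return np.tile(cov, (n_components, 1, 1))
    if model.covariance_type == 'diag':       # (n_components, d) variances
        return np.array([np.diag(v) for v in cov])
    if model.covariance_type == 'spherical':  # (n_components,) single variances
        return np.array([v * np.eye(d) for v in cov])
    raise ValueError(f"unknown covariance_type: {model.covariance_type}")


rng = np.random.RandomState(5)
X = np.dot(rng.randn(500, 2), rng.randn(2, 2))
for cov_type in ['diag', 'spherical', 'full', 'tied']:
    model = GaussianMixture(1, covariance_type=cov_type).fit(X)
    print(cov_type, model.covariances_.shape, full_covariances(model).shape)
```

For this single-component fit, the raw shapes should be (1, 2), (1,), (1, 2, 2) and (2, 2) respectively, and every expanded array has shape (1, 2, 2).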