So verwenden Sie die NumPy argmax () -Funktion in Python

In diesem Tutorial lernen Sie, wie Sie die NumPy-Funktion argmax() verwenden, um den Index des maximalen Elements in Arrays zu finden.

NumPy ist eine leistungsstarke Bibliothek für wissenschaftliches Rechnen in Python; Es bietet N-dimensionale Arrays, die leistungsfähiger sind als Python-Listen. Eine der häufigsten Operationen, die Sie beim Arbeiten mit NumPy-Arrays ausführen, besteht darin, den Maximalwert im Array zu finden. Manchmal möchten Sie jedoch den Index finden, bei dem der maximale Wert auftritt.

Die Funktion argmax() hilft Ihnen, den Index des Maximums sowohl in eindimensionalen als auch in mehrdimensionalen Arrays zu finden. Lassen Sie uns fortfahren, um zu lernen, wie es funktioniert.

So finden Sie den Index des maximalen Elements in einem NumPy-Array

Um diesem Tutorial folgen zu können, müssen Sie Python und NumPy installiert haben. Sie können mitcodieren, indem Sie eine Python-REPL starten oder ein Jupyter-Notebook starten.

Importieren wir zunächst NumPy unter dem üblichen Alias ​​np.

import numpy as np

Sie können die Funktion NumPy max() verwenden, um den Maximalwert in einem Array (optional entlang einer bestimmten Achse) zu erhalten.

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))

# Output
10

In diesem Fall gibt np.max(array_1) 10 zurück, was korrekt ist.

Angenommen, Sie möchten den Index finden, an dem der maximale Wert im Array auftritt. Sie können den folgenden zweistufigen Ansatz wählen:

  • Finden Sie das maximale Element.
  • Finden Sie den Index des maximalen Elements.
  • In array_1 tritt der Maximalwert von 10 bei Index 4 auf, nach der Null-Indizierung. Das erste Element hat den Index 0; das zweite Element befindet sich auf Index 1 und so weiter.

    Um den Index zu finden, bei dem das Maximum auftritt, können Sie die Funktion NumPy where() verwenden. np.where(condition) gibt ein Array aller Indizes zurück, bei denen die Bedingung wahr ist.

    Sie müssen auf das Array tippen und auf das Element am ersten Index zugreifen. Um herauszufinden, wo der Maximalwert auftritt, setzen wir die Bedingung auf array_1==10; Denken Sie daran, dass 10 der maximale Wert in array_1 ist.

    print(int(np.where(array_1==10)[0]))
    
    # Output
    4

    Wir haben np.where() nur mit der Bedingung verwendet, aber dies ist nicht die empfohlene Methode zur Verwendung dieser Funktion.

    📑 Hinweis: NumPy where() Funktion:
    np.where(condition,x,y) gibt zurück:

    – Elemente von x, wenn die Bedingung wahr ist, und
    – Elemente von y, wenn die Bedingung falsch ist.

    Daher können wir durch Verketten der Funktionen np.max() und np.where() das maximale Element finden, gefolgt von dem Index, an dem es auftritt.

    Anstelle des obigen zweistufigen Prozesses können Sie die NumPy-Funktion argmax() verwenden, um den Index des größten Elements im Array abzurufen.

      So verwenden Sie iTunes Musik in PowerPoint-Präsentationen

    Syntax der NumPy argmax()-Funktion

    Die allgemeine Syntax zur Verwendung der Funktion NumPy argmax() lautet wie folgt:

    np.argmax(array,axis,out)
    # we've imported numpy under the alias np

    In der obigen Syntax:

    • array ist ein beliebiges gültiges NumPy-Array.
    • Achse ist ein optionaler Parameter. Wenn Sie mit mehrdimensionalen Arrays arbeiten, können Sie den Achsenparameter verwenden, um den Index des Maximums entlang einer bestimmten Achse zu finden.
    • out ist ein weiterer optionaler Parameter. Sie können den out-Parameter auf ein NumPy-Array setzen, um die Ausgabe der argmax()-Funktion zu speichern.

    Hinweis: Ab NumPy-Version 1.22.0 gibt es einen zusätzlichen keepdims-Parameter. Wenn wir den Achsenparameter im Funktionsaufruf argmax() angeben, wird das Array entlang dieser Achse verkleinert. Wenn Sie den Parameter keepdims jedoch auf True setzen, wird sichergestellt, dass die zurückgegebene Ausgabe die gleiche Form wie das Eingabearray hat.

    Verwenden von NumPy argmax(), um den Index des maximalen Elements zu finden

    #1. Lassen Sie uns die Funktion NumPy argmax() verwenden, um den Index des größten Elements in array_1 zu finden.

    array_1 = np.array([1,5,7,2,10,9,8,4])
    print(np.argmax(array_1))
    
    # Output
    4

    Die Funktion argmax() gibt 4 zurück, was richtig ist! ✅

    #2. Wenn wir array_1 so umdefinieren, dass 10 zweimal vorkommt, gibt die Funktion argmax() nur den Index des ersten Vorkommens zurück.

    array_1 = np.array([1,5,7,2,10,10,8,4])
    print(np.argmax(array_1))
    
    # Output
    4

    Für die restlichen Beispiele verwenden wir die Elemente von array_1, die wir in Beispiel 1 definiert haben.

    Verwenden von NumPy argmax() zum Finden des Index des maximalen Elements in einem 2D-Array

    Lassen Sie uns das NumPy-Array array_1 in ein zweidimensionales Array mit zwei Zeilen und vier Spalten umformen.

    array_2 = array_1.reshape(2,4)
    print(array_2)
    
    # Output
    [[ 1  5  7  2]
     [10  9  8  4]]

    Bei einem zweidimensionalen Array bezeichnet die Achse 0 die Zeilen und die Achse 1 die Spalten. NumPy-Arrays folgen der Nullindizierung. Die Indizes der Zeilen und Spalten für das NumPy-Array array_2 lauten also wie folgt:

    Rufen wir nun die Funktion argmax() für das zweidimensionale Array array_2 auf.

    print(np.argmax(array_2))
    
    # Output
    4

    Obwohl wir argmax() für das zweidimensionale Array aufgerufen haben, gibt es immer noch 4 zurück. Dies ist identisch mit der Ausgabe für das eindimensionale Array array_1 aus dem vorherigen Abschnitt.

    Warum passiert das? 🤔

    Dies liegt daran, dass wir keinen Wert für den Achsenparameter angegeben haben. Wenn dieser Achsenparameter nicht festgelegt ist, gibt die Funktion argmax() standardmäßig den Index des größten Elements entlang des reduzierten Arrays zurück.

    Was ist ein abgeflachtes Array? Wenn es ein N-dimensionales Array der Form d1 x d2 x … x dN gibt, wobei d1, d2 bis dN die Größen des Arrays entlang der N Dimensionen sind, dann ist das abgeflachte Array ein langes eindimensionales Array der Größe d1 * d2 * … * dN.

      3 Tools zum Erstellen und Einbetten von Grafiken oder Diagrammen online

    Um zu überprüfen, wie das abgeflachte Array für array_2 aussieht, können Sie die Methode flatten() aufrufen, wie unten gezeigt:

    array_2.flatten()
    
    # Output
    array([ 1,  5,  7,  2, 10,  9,  8,  4])

    Index des maximalen Elements entlang der Zeilen (Achse = 0)

    Lassen Sie uns fortfahren, um den Index des maximalen Elements entlang der Zeilen zu finden (Achse = 0).

    np.argmax(array_2,axis=0)
    
    # Output
    array([1, 1, 1, 1])

    Diese Ausgabe kann etwas schwierig zu verstehen sein, aber wir werden verstehen, wie sie funktioniert.

    Wir haben den Achsenparameter auf Null gesetzt (Achse = 0), da wir den Index des größten Elements entlang der Zeilen finden möchten. Daher gibt die Funktion argmax() die Zeilennummer zurück, in der das maximale Element vorkommt – für jede der drei Spalten.

    Lassen Sie uns dies zum besseren Verständnis visualisieren.

    Aus dem obigen Diagramm und der Ausgabe von argmax() haben wir Folgendes:

    • Für die erste Spalte bei Index 0 tritt in der zweiten Zeile bei Index = 1 der Maximalwert 10 auf.
    • Für die zweite Spalte bei Index 1 tritt der Maximalwert 9 in der zweiten Zeile bei Index = 1 auf.
    • Für die dritte und vierte Spalte bei Index 2 und 3 treten die Maximalwerte 8 und 4 beide in der zweiten Zeile bei Index = 1 auf.

    Genau aus diesem Grund haben wir das Ausgabearray ([1, 1, 1, 1]), da das maximale Element entlang der Zeilen in der zweiten Zeile auftritt (für alle Spalten).

    Index des maximalen Elements entlang der Spalten (Achse = 1)

    Als Nächstes verwenden wir die Funktion argmax(), um den Index des maximalen Elements entlang der Spalten zu finden.

    Führen Sie das folgende Code-Snippet aus und beobachten Sie die Ausgabe.

    np.argmax(array_2,axis=1)
    array([2, 0])

    Kannst du die Ausgabe parsen?

    Wir haben Achse = 1 gesetzt, um den Index des größten Elements entlang der Spalten zu berechnen.

    Die Funktion argmax() gibt für jede Zeile die Spaltennummer zurück, in der der Maximalwert auftritt.

    Hier ist eine visuelle Erklärung:

    Aus dem obigen Diagramm und der Ausgabe von argmax() haben wir Folgendes:

    • Für die erste Zeile bei Index 0 tritt in der dritten Spalte bei Index = 2 der Maximalwert 7 auf.
    • Für die zweite Zeile bei Index 1 tritt in der ersten Spalte bei Index = 0 der Maximalwert 10 auf.

    Ich hoffe, Sie verstehen jetzt, was die Ausgabe, array([2, 0]) meint.

      Wie melde ich mich mit meinem Spotify-Konto bei Hulu an?

    Verwenden des optionalen out-Parameters in NumPy argmax()

    Sie können den optionalen Parameter out the in der Funktion NumPy argmax() verwenden, um die Ausgabe in einem NumPy-Array zu speichern.

    Lassen Sie uns ein Array aus Nullen initialisieren, um die Ausgabe des vorherigen argmax()-Funktionsaufrufs zu speichern – um den Index des Maximums entlang der Spalten zu finden (axis= 1).

    out_arr = np.zeros((2,))
    print(out_arr)
    [0. 0.]

    Betrachten wir nun das Beispiel, in dem der Index des maximalen Elements entlang der Spalten (Achse = 1) ermittelt wird, und setzen Sie out auf out_arr, das wir oben definiert haben.

    np.argmax(array_2,axis=1,out=out_arr)

    Wir sehen, dass der Python-Interpreter einen TypeError auslöst, da out_arr standardmäßig mit einem Array von Gleitkommazahlen initialisiert wurde.

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
         56     try:
    ---> 57         return bound(*args, **kwds)
         58     except TypeError:
    
    TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

    Daher ist es wichtig, beim Festlegen des out-Parameters für das Ausgabearray sicherzustellen, dass das Ausgabearray die richtige Form und den richtigen Datentyp hat. Da Array-Indizes immer ganze Zahlen sind, sollten wir den dtype-Parameter auf int setzen, wenn wir das Ausgabe-Array definieren.

    out_arr = np.zeros((2,),dtype=int)
    print(out_arr)
    
    # Output
    [0 0]

    Wir können jetzt fortfahren und die Funktion argmax() sowohl mit den Parametern axis als auch out aufrufen, und dieses Mal läuft sie ohne Fehler.

    np.argmax(array_2,axis=1,out=out_arr)

    Auf die Ausgabe der Funktion argmax() kann nun im Array out_arr zugegriffen werden.

    print(out_arr)
    # Output
    [2 0]

    Fazit

    Ich hoffe, dieses Tutorial hat Ihnen geholfen, die NumPy-Funktion argmax() zu verstehen. Sie können die Codebeispiele in einem Jupyter-Notebook ausführen.

    Lassen Sie uns überprüfen, was wir gelernt haben.

    • Die Funktion NumPy argmax() gibt den Index des größten Elements in einem Array zurück. Wenn das maximale Element mehr als einmal in einem Array a vorkommt, gibt np.argmax(a) den Index des ersten Vorkommens des Elements zurück.
    • Wenn Sie mit mehrdimensionalen Arrays arbeiten, können Sie den optionalen Achsenparameter verwenden, um den Index des größten Elements entlang einer bestimmten Achse zu erhalten. Zum Beispiel in einem zweidimensionalen Array: Indem Sie Achse = 0 und Achse = 1 setzen, können Sie den Index des größten Elements entlang der Zeilen bzw. Spalten erhalten.
    • Wenn Sie den zurückgegebenen Wert in einem anderen Array speichern möchten, können Sie den optionalen out-Parameter auf das Ausgabe-Array setzen. Das Ausgabearray sollte jedoch eine kompatible Form und einen kompatiblen Datentyp haben.

    Sehen Sie sich als Nächstes die ausführliche Anleitung zu Python-Sets an.