How can I make this PyTorch heatmap function faster and more efficient?

ProGamerGov

I have this function that creates a sort of heatmap for 2D tensors, but it's painfully slow on larger inputs. How can I speed it up and make it more efficient?

import torch
import numpy as np
import matplotlib.pyplot as plt


def heatmap(
    tensor: torch.Tensor,
) -> torch.Tensor:
    assert tensor.dim() == 2

    def color_tensor(x: torch.Tensor) -> torch.Tensor:
        if x < 0:
            x = -x
            if x < 0.5:
                x = x * 2
                return (1 - x) * torch.tensor(
                    [0.9686, 0.9686, 0.9686]
                ) + x * torch.tensor([0.5725, 0.7725, 0.8706])
            else:
                x = (x - 0.5) * 2
                return (1 - x) * torch.tensor(
                    [0.5725, 0.7725, 0.8706]
                ) + x * torch.tensor([0.0196, 0.4431, 0.6902])
        else:
            if x < 0.5:
                x = x * 2
                return (1 - x) * torch.tensor(
                    [0.9686, 0.9686, 0.9686]
                ) + x * torch.tensor([0.9569, 0.6471, 0.5098])
            else:
                x = (x - 0.5) * 2
                return (1 - x) * torch.tensor(
                    [0.9569, 0.6471, 0.5098]
                ) + x * torch.tensor([0.7922, 0.0000, 0.1255])

    return torch.stack(
        [torch.stack([color_tensor(x) for x in t]) for t in tensor]
    ).permute(2, 0, 1)

x = torch.randn(3,3)
x = x / x.max()
x_out = heatmap(x)

x_out = (x_out.permute(1, 2, 0) * 255).numpy()
plt.imshow(x_out.astype(np.uint8))
plt.axis("off")
plt.show()

An example of the output:

[Image: example output]

armamut

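You need to get rid of the ifs and the for loops and vectorize the function. To do that, build a boolean mask for each branch of the original function, multiply each mask by that branch's formula, and sum the results so that everything is computed in one pass. Here it is: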


def heatmap(tensor: torch.Tensor) -> torch.Tensor:
    assert tensor.dim() == 2

    # Add a channel dimension of size 3 so that multiplying by the RGB
    # color vectors broadcasts: (H, W) -> (H, W, 3).
    xt = tensor.expand((3, tensor.shape[0], tensor.shape[1])).permute(1, 2, 0)

    # Each term is mask * formula. The mask, e.g. (xt >= 0) * (xt < 0.5),
    # selects one branch of the original function and zeroes out the rest.
    color_tensor = (
        (xt >= 0) * (xt < 0.5) * ((1 - xt * 2) * torch.tensor([0.9686, 0.9686, 0.9686]) + xt * 2 * torch.tensor([0.9569, 0.6471, 0.5098]))
        +
        (xt >= 0) * (xt >= 0.5) * ((1 - (xt - 0.5) * 2) * torch.tensor([0.9569, 0.6471, 0.5098]) + (xt - 0.5) * 2 * torch.tensor([0.7922, 0.0000, 0.1255]))
        +
        (xt < 0) * (xt > -0.5) * ((1 - (-xt * 2)) * torch.tensor([0.9686, 0.9686, 0.9686]) + (-xt * 2) * torch.tensor([0.5725, 0.7725, 0.8706]))
        +
        (xt < 0) * (xt <= -0.5) * ((1 - (-xt - 0.5) * 2) * torch.tensor([0.5725, 0.7725, 0.8706]) + (-xt - 0.5) * 2 * torch.tensor([0.0196, 0.4431, 0.6902]))
    ).permute(2, 0, 1)  # back to channels-first: (3, H, W)

    return color_tensor
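
As a possible alternative to the four masked terms, the same piecewise-linear mapping could be written as a lookup into a table of color stops. This is only a sketch, not part of the answer above: the names `STOPS` and `heatmap_lut` are made up for illustration, and it assumes the input lies in [-1, 1]. Unlike the original function, which extrapolates for values outside that range, this version clamps them.

```python
import torch

# The five color stops from the question's code, ordered from x = -1 to x = +1.
STOPS = torch.tensor([
    [0.0196, 0.4431, 0.6902],  # x = -1.0
    [0.5725, 0.7725, 0.8706],  # x = -0.5
    [0.9686, 0.9686, 0.9686],  # x =  0.0
    [0.9569, 0.6471, 0.5098],  # x = +0.5
    [0.7922, 0.0000, 0.1255],  # x = +1.0
])


def heatmap_lut(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical lookup-table variant; returns a (3, H, W) tensor."""
    assert tensor.dim() == 2
    stops = STOPS.to(device=tensor.device, dtype=tensor.dtype)

    # Assumption: values outside [-1, 1] are clamped (the original extrapolates).
    # Map [-1, 1] onto [0, 4], i.e. a fractional index into the stop table.
    pos = (tensor.clamp(-1, 1) + 1) * 2

    # Index of the lower stop; capped at 3 so that lo + 1 stays in range.
    lo = pos.floor().clamp(max=3).long()

    # Fractional distance between the lower and upper stop, shape (H, W, 1).
    frac = (pos - lo).unsqueeze(-1)

    # Linear interpolation between neighbouring stops, shape (H, W, 3).
    out = (1 - frac) * stops[lo] + frac * stops[lo + 1]
    return out.permute(2, 0, 1)
```

It is called the same way as `heatmap`, e.g. `x_out = heatmap_lut(x)`. Each element is touched once instead of being evaluated in all four branches, though whether that is measurably faster would need to be timed on your own inputs.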
