Skip to content

Commit 7660716

Browse files
committed
correct bug jax and simplify code
1 parent b71511a commit 7660716

File tree

13 files changed

+464
-61
lines changed

13 files changed

+464
-61
lines changed

ot/sliced.py

Lines changed: 15 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,12 @@ def sliced_plans(
790790
else:
791791
n_proj = thetas.shape[0]
792792

793+
def dist(i, j):
794+
if metric == "sqeuclidean":
795+
return nx.sum((X[i] - Y[j]) ** 2, axis=1)
796+
else:
797+
return nx.sum(nx.abs(X[i] - Y[j]) ** p, axis=1) ** (1 / p)
798+
793799
# project on each theta: (n or m, d) -> (n or m, n_proj)
794800
X_theta = X @ thetas.T # shape (n, n_proj)
795801
Y_theta = Y @ thetas.T # shape (m, n_proj)
@@ -798,33 +804,15 @@ def sliced_plans(
798804
# sigma[:, i_proj] is a permutation sorting X_theta[:, i_proj]
799805
sigma = nx.argsort(X_theta, axis=0) # (n, n_proj)
800806
tau = nx.argsort(Y_theta, axis=0) # (m, n_proj)
801-
if metric in ("minkowski", "euclidean", "cityblock"):
802-
costs = [
803-
nx.sum(
804-
(
805-
(nx.sum(nx.abs(X[sigma[:, k]] - Y[tau[:, k]]) ** p, axis=1))
806-
** (1 / p)
807-
)
808-
/ n
809-
)
810-
for k in range(n_proj)
811-
]
812-
else: # metric = "sqeuclidean":
813-
costs = [
814-
nx.sum((nx.sum((X[sigma[:, k]] - Y[tau[:, k]]) ** 2, axis=1)) / n)
815-
for k in range(n_proj)
816-
]
807+
808+
costs = [nx.sum(dist(sigma[:, k], tau[:, k]) / n) for k in range(n_proj)]
817809

818810
a = nx.ones(n) / n
819811
plan = [
820812
nx.coo_matrix(a, sigma[:, k], tau[:, k], shape=(n, m), type_as=a)
821813
for k in range(n_proj)
822814
]
823815

824-
if not dense and str(nx) == "jax":
825-
warnings.warn("JAX does not support sparse matrices, converting to dense")
826-
plan = [nx.todense(plan[k]) for k in range(n_proj)]
827-
828816
else: # we compute plans
829817
_, plan = wasserstein_1d(
830818
X_theta, Y_theta, a, b, p, require_sort=True, return_plan=True
@@ -835,56 +823,23 @@ def sliced_plans(
835823
warnings.warn(
836824
"JAX does not support sparse matrices, converting to dense"
837825
)
838-
839826
plan = [nx.todense(plan[k]) for k in range(n_proj)]
840-
827+
idx_non_zeros = [np.nonzero(plan[k]) for k in range(n_proj)]
841828
costs = [
842829
nx.sum(
843-
(
844-
(
845-
nx.sum(
846-
nx.abs(
847-
X[np.nonzero(plan[k])[0]]
848-
- Y[np.nonzero(plan[k])[1]]
849-
)
850-
** p,
851-
axis=1,
852-
)
853-
)
854-
** (1 / p)
855-
)
856-
* plan[np.nonzero(plan[k])]
830+
dist(idx_non_zeros[k][0], idx_non_zeros[k][1])
831+
* plan[k][idx_non_zeros[k][0], idx_non_zeros[k][1]]
857832
)
858833
for k in range(n_proj)
859834
]
860-
861835
else:
862836
if str(nx) == "tensorflow": # tf does not support multiple indexing
863837
plan = [plan[k].tocsr().tocoo() for k in range(n_proj)]
864838

865-
if metric in ("minkowski", "euclidean", "cityblock"):
866-
costs = [
867-
nx.sum(
868-
(
869-
(
870-
nx.sum(
871-
nx.abs(X[plan[k].row] - Y[plan[k].col]) ** p, axis=1
872-
)
873-
)
874-
** (1 / p)
875-
)
876-
* plan[k].data
877-
)
878-
for k in range(n_proj)
879-
]
880-
else: # metric == "sqeuclidean"
881-
costs = [
882-
nx.sum(
883-
(nx.sum((X[plan[k].row] - Y[plan[k].col]) ** 2, axis=1))
884-
* plan[k].data
885-
)
886-
for k in range(n_proj)
887-
]
839+
costs = [
840+
nx.sum(dist(plan[k].row, plan[k].col) * plan[k].data)
841+
for k in range(n_proj)
842+
]
888843

889844
if dense and not str(nx) == "jax":
890845
plan = [nx.todense(plan[k]) for k in range(n_proj)]

test-pot/bin/Activate.ps1

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
<#
2+
.Synopsis
3+
Activate a Python virtual environment for the current PowerShell session.
4+
5+
.Description
6+
Pushes the python executable for a virtual environment to the front of the
7+
$Env:PATH environment variable and sets the prompt to signify that you are
8+
in a Python virtual environment. Makes use of the command line switches as
9+
well as the `pyvenv.cfg` file values present in the virtual environment.
10+
11+
.Parameter VenvDir
12+
Path to the directory that contains the virtual environment to activate. The
13+
default value for this is the parent of the directory that the Activate.ps1
14+
script is located within.
15+
16+
.Parameter Prompt
17+
The prompt prefix to display when this virtual environment is activated. By
18+
default, this prompt is the name of the virtual environment folder (VenvDir)
19+
surrounded by parentheses and followed by a single space (ie. '(.venv) ').
20+
21+
.Example
22+
Activate.ps1
23+
Activates the Python virtual environment that contains the Activate.ps1 script.
24+
25+
.Example
26+
Activate.ps1 -Verbose
27+
Activates the Python virtual environment that contains the Activate.ps1 script,
28+
and shows extra information about the activation as it executes.
29+
30+
.Example
31+
Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv
32+
Activates the Python virtual environment located in the specified location.
33+
34+
.Example
35+
Activate.ps1 -Prompt "MyPython"
36+
Activates the Python virtual environment that contains the Activate.ps1 script,
37+
and prefixes the current prompt with the specified string (surrounded in
38+
parentheses) while the virtual environment is active.
39+
40+
.Notes
41+
On Windows, it may be required to enable this Activate.ps1 script by setting the
42+
execution policy for the user. You can do this by issuing the following PowerShell
43+
command:
44+
45+
PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
46+
47+
For more information on Execution Policies:
48+
https://go.microsoft.com/fwlink/?LinkID=135170
49+
50+
#>
51+
Param(
52+
[Parameter(Mandatory = $false)]
53+
[String]
54+
$VenvDir,
55+
[Parameter(Mandatory = $false)]
56+
[String]
57+
$Prompt
58+
)
59+
60+
<# Function declarations --------------------------------------------------- #>
61+
62+
<#
63+
.Synopsis
64+
Remove all shell session elements added by the Activate script, including the
65+
addition of the virtual environment's Python executable from the beginning of
66+
the PATH variable.
67+
68+
.Parameter NonDestructive
69+
If present, do not remove this function from the global namespace for the
70+
session.
71+
72+
#>
73+
function global:deactivate ([switch]$NonDestructive) {
74+
# Revert to original values
75+
76+
# The prior prompt:
77+
if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) {
78+
Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt
79+
Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT
80+
}
81+
82+
# The prior PYTHONHOME:
83+
if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) {
84+
Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME
85+
Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME
86+
}
87+
88+
# The prior PATH:
89+
if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) {
90+
Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH
91+
Remove-Item -Path Env:_OLD_VIRTUAL_PATH
92+
}
93+
94+
# Just remove the VIRTUAL_ENV altogether:
95+
if (Test-Path -Path Env:VIRTUAL_ENV) {
96+
Remove-Item -Path env:VIRTUAL_ENV
97+
}
98+
99+
# Just remove VIRTUAL_ENV_PROMPT altogether.
100+
if (Test-Path -Path Env:VIRTUAL_ENV_PROMPT) {
101+
Remove-Item -Path env:VIRTUAL_ENV_PROMPT
102+
}
103+
104+
# Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether:
105+
if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) {
106+
Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force
107+
}
108+
109+
# Leave deactivate function in the global namespace if requested:
110+
if (-not $NonDestructive) {
111+
Remove-Item -Path function:deactivate
112+
}
113+
}
114+
115+
<#
116+
.Description
117+
Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the
118+
given folder, and returns them in a map.
119+
120+
For each line in the pyvenv.cfg file, if that line can be parsed into exactly
121+
two strings separated by `=` (with any amount of whitespace surrounding the =)
122+
then it is considered a `key = value` line. The left hand string is the key,
123+
the right hand is the value.
124+
125+
If the value starts with a `'` or a `"` then the first and last character is
126+
stripped from the value before being captured.
127+
128+
.Parameter ConfigDir
129+
Path to the directory that contains the `pyvenv.cfg` file.
130+
#>
131+
function Get-PyVenvConfig(
132+
[String]
133+
$ConfigDir
134+
) {
135+
Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg"
136+
137+
# Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue).
138+
$pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue
139+
140+
# An empty map will be returned if no config file is found.
141+
$pyvenvConfig = @{ }
142+
143+
if ($pyvenvConfigPath) {
144+
145+
Write-Verbose "File exists, parse `key = value` lines"
146+
$pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath
147+
148+
$pyvenvConfigContent | ForEach-Object {
149+
$keyval = $PSItem -split "\s*=\s*", 2
150+
if ($keyval[0] -and $keyval[1]) {
151+
$val = $keyval[1]
152+
153+
# Remove extraneous quotations around a string value.
154+
if ("'""".Contains($val.Substring(0, 1))) {
155+
$val = $val.Substring(1, $val.Length - 2)
156+
}
157+
158+
$pyvenvConfig[$keyval[0]] = $val
159+
Write-Verbose "Adding Key: '$($keyval[0])'='$val'"
160+
}
161+
}
162+
}
163+
return $pyvenvConfig
164+
}
165+
166+
167+
<# Begin Activate script --------------------------------------------------- #>
168+
169+
# Determine the containing directory of this script
170+
$VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition
171+
$VenvExecDir = Get-Item -Path $VenvExecPath
172+
173+
Write-Verbose "Activation script is located in path: '$VenvExecPath'"
174+
Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)"
175+
Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)"
176+
177+
# Set values required in priority: CmdLine, ConfigFile, Default
178+
# First, get the location of the virtual environment, it might not be
179+
# VenvExecDir if specified on the command line.
180+
if ($VenvDir) {
181+
Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values"
182+
}
183+
else {
184+
Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir."
185+
$VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/")
186+
Write-Verbose "VenvDir=$VenvDir"
187+
}
188+
189+
# Next, read the `pyvenv.cfg` file to determine any required value such
190+
# as `prompt`.
191+
$pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir
192+
193+
# Next, set the prompt from the command line, or the config file, or
194+
# just use the name of the virtual environment folder.
195+
if ($Prompt) {
196+
Write-Verbose "Prompt specified as argument, using '$Prompt'"
197+
}
198+
else {
199+
Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value"
200+
if ($pyvenvCfg -and $pyvenvCfg['prompt']) {
201+
Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'"
202+
$Prompt = $pyvenvCfg['prompt'];
203+
}
204+
else {
205+
Write-Verbose " Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virtual environment)"
206+
Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'"
207+
$Prompt = Split-Path -Path $venvDir -Leaf
208+
}
209+
}
210+
211+
Write-Verbose "Prompt = '$Prompt'"
212+
Write-Verbose "VenvDir='$VenvDir'"
213+
214+
# Deactivate any currently active virtual environment, but leave the
215+
# deactivate function in place.
216+
deactivate -nondestructive
217+
218+
# Now set the environment variable VIRTUAL_ENV, used by many tools to determine
219+
# that there is an activated venv.
220+
$env:VIRTUAL_ENV = $VenvDir
221+
222+
if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) {
223+
224+
Write-Verbose "Setting prompt to '$Prompt'"
225+
226+
# Set the prompt to include the env name
227+
# Make sure _OLD_VIRTUAL_PROMPT is global
228+
function global:_OLD_VIRTUAL_PROMPT { "" }
229+
Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT
230+
New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt
231+
232+
function global:prompt {
233+
Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) "
234+
_OLD_VIRTUAL_PROMPT
235+
}
236+
$env:VIRTUAL_ENV_PROMPT = $Prompt
237+
}
238+
239+
# Clear PYTHONHOME
240+
if (Test-Path -Path Env:PYTHONHOME) {
241+
Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME
242+
Remove-Item -Path Env:PYTHONHOME
243+
}
244+
245+
# Add the venv to the PATH
246+
Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH
247+
$Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH"

0 commit comments

Comments
 (0)