Skip to content

Commit b1ea5e0

Browse files
committed
Fix generic plot and add unit test
1 parent 81cbdb5 commit b1ea5e0

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

lightning/main.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def set_host(self, host):
117117
self.host = host
118118
return self
119119

120-
def plot(self, type=None, **kwargs):
120+
def plot(self, data=None, type=None):
121121
"""
122122
Generic plotting function.
123123
@@ -128,11 +128,22 @@ def plot(self, type=None, **kwargs):
128128
129129
Most useful when providing data to custom visualizations, as opposed to the included plot types
130130
(e.g. lightning.scatter, lightning.line, etc.) which do automatic parsing and formatting.
131+
132+
Parameters
133+
----------
134+
data : dict
135+
Dictionary with data to plot
136+
137+
type : str
138+
Name of plot (e.g. 'line' or 'scatter')
131139
"""
132140

133141
from types.plots import Generic
142+
143+
if not hasattr(self, 'session'):
144+
self.create_session()
134145

135-
viz = Generic.baseplot(self.session, type=type, **kwargs)
146+
viz = Generic.baseplot(self.session, type, data)
136147
self.session.visualizations.append(viz)
137148
return viz
138149

lightning/types/plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class Generic(Base):
99

1010
@staticmethod
1111
def clean(data):
12-
return {'data': data}
12+
return data
1313

1414

1515
@viztype

test/test_client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ def test_create_scatter(self):
4141
assert hasattr(viz, 'id')
4242

4343

44+
def test_create_generic(self):
45+
46+
series = random.randn(5,100)
47+
48+
viz = lightning.plot(data={"series": series}, type='line')
49+
viz = lightning.plot({"series": series}, 'line')
50+
51+
assert isinstance(viz, Visualization)
52+
assert hasattr(viz, 'id')
53+
54+
4455
def test_create_line(self):
4556

4657
series = random.randn(5,100)

0 commit comments

Comments
 (0)